Documentation
Overview ¶
Package inference provides a high-level API for loading GGUF models and running text generation, chat, embedding, and speculative decoding with minimal boilerplate.
Loading Models ¶
There are two entry points for loading a model:
- Load resolves a model by name or HuggingFace repo ID, pulling it from the registry if not already cached, and returns a ready-to-use Model.
- LoadFile loads a model directly from a local GGUF file path.
Both accept functional Option values to configure the compute device, cache directory, sequence length, and other parameters:
m, err := inference.Load("gemma-3-1b-q4",
	inference.WithDevice("cuda"),
	inference.WithMaxSeqLen(4096),
)
if err != nil {
	log.Fatal(err)
}
defer m.Close()

text, err := m.Generate(ctx, "Explain gradient descent briefly.",
	inference.WithMaxTokens(256),
	inference.WithTemperature(0.7),
)
Model Methods ¶
A loaded Model exposes several generation methods:
- Model.Generate produces text from a prompt and returns the full result.
- Model.GenerateStream delivers tokens incrementally via a callback.
- Model.GenerateBatch processes multiple prompts concurrently.
- Model.Chat formats a slice of Message values using the model's chat template and generates a Response with token usage statistics.
- Model.Embed returns an L2-normalized embedding vector for a text input by mean-pooling the model's token embedding table.
- Model.SpeculativeGenerate runs speculative decoding with a smaller draft model to accelerate generation from a larger target model.
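A typical chat round-trip looks like the sketch below. The Role, Content, and Text field names are assumptions for illustration (the Message and Response field sets are not shown on this page), so check the type definitions before relying on them:

```go
// Sketch only: the field names on Message and Response are assumptions;
// consult the type definitions below.
msgs := []inference.Message{
	{Role: "system", Content: "You are a concise assistant."},
	{Role: "user", Content: "What is a KV cache?"},
}
resp, err := m.Chat(ctx, msgs, inference.WithMaxTokens(128))
if err != nil {
	log.Fatal(err)
}
fmt.Println(resp.Text)
```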
Load Options ¶
The following Option functions configure model loading:
- WithDevice — compute device: "cpu", "cuda", "cuda:N", "rocm", "opencl"
- WithCacheDir — local directory for cached model files
- WithMaxSeqLen — override the model's default maximum sequence length
- WithRegistry — supply a custom model registry
- WithBackend — select "tensorrt" for TensorRT-optimized inference
- WithPrecision — set TensorRT compute precision ("fp16")
- WithDType — set GPU compute precision ("fp16", "fp8")
- WithKVDtype — set KV cache storage precision ("fp16")
- WithMmap — enable memory-mapped model loading on unix
Generate Options ¶
The following GenerateOption functions configure sampling for generation methods:
- WithTemperature — sampling temperature (higher = more random)
- WithTopK — top-K sampling cutoff
- WithTopP — nucleus (top-P) sampling threshold
- WithMaxTokens — maximum number of tokens to generate
- WithRepetitionPenalty — penalize repeated tokens
- WithStopStrings — strings that terminate generation
- WithGrammar — constrained decoding via a grammar state machine
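To make the top-K and top-P knobs concrete, the self-contained sketch below applies a top-K cutoff followed by a nucleus (top-P) filter to a toy probability distribution. It illustrates only the semantics of the two options; it is not the package's sampler.

```go
package main

import (
	"fmt"
	"sort"
)

// topKTopP returns the probabilities that survive a top-K cutoff followed
// by a nucleus (top-P) filter, renormalized. Illustrative only.
func topKTopP(probs []float64, k int, p float64) []float64 {
	sorted := append([]float64(nil), probs...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] > sorted[j] })
	if k < len(sorted) {
		sorted = sorted[:k] // top-K: keep only the K most likely tokens
	}
	// top-P: keep the smallest prefix whose cumulative mass reaches p.
	cum, cut := 0.0, len(sorted)
	for i, q := range sorted {
		cum += q
		if cum >= p {
			cut = i + 1
			break
		}
	}
	sorted = sorted[:cut]
	// Renormalize the surviving mass.
	total := 0.0
	for _, q := range sorted {
		total += q
	}
	out := make([]float64, len(sorted))
	for i, q := range sorted {
		out[i] = q / total
	}
	return out
}

func main() {
	probs := []float64{0.5, 0.2, 0.15, 0.1, 0.05}
	// top-K (k=4) drops the least likely token; top-P (p=0.8) then keeps
	// the three most likely candidates, renormalized.
	fmt.Println(topKTopP(probs, 4, 0.8))
}
```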
Model Aliases ¶
Short aliases such as "gemma-3-1b-q4" and "llama-3-8b-q4" map to full HuggingFace repository IDs. Use ResolveAlias to look up the mapping and RegisterAlias to add custom aliases.
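The alias mechanism amounts to a map lookup with a pass-through default. A minimal self-contained sketch of the behavior described above (the real package maintains its own registry with built-in entries; the repo IDs here are hypothetical):

```go
package main

import "fmt"

// aliases mirrors the behavior of RegisterAlias/ResolveAlias described
// above; the real package maintains its own registry with built-in entries.
var aliases = map[string]string{}

func registerAlias(shortName, repoID string) { aliases[shortName] = repoID }

// resolveAlias returns the mapped repo ID, or the name unchanged if it is
// not a registered alias.
func resolveAlias(name string) string {
	if repo, ok := aliases[name]; ok {
		return repo
	}
	return name
}

func main() {
	registerAlias("my-model-q4", "example-org/my-model-GGUF") // hypothetical IDs
	fmt.Println(resolveAlias("my-model-q4"))       // example-org/my-model-GGUF
	fmt.Println(resolveAlias("org/explicit-repo")) // passed through unchanged
}
```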
Related Packages ¶
For lower-level control over text generation, KV caching, and sampling, see the github.com/zerfoo/zerfoo/generate package. For an OpenAI-compatible HTTP server built on top of this package, see github.com/zerfoo/zerfoo/serve.
Index ¶
- func ConvertGraphToTRT(g *graph.Graph[float32], workspaceBytes int, fp16 bool, ...) (*trtConversionResult, error)
- func LoadTRTEngine(key string) ([]byte, error)
- func RegisterAlias(shortName, repoID string)
- func ResolveAlias(name string) string
- func SaveTRTEngine(key string, data []byte) error
- func TRTCacheKey(modelID, precision string) (string, error)
- type ArchConfigRegistry
- type ConfigParser
- type ConstantValueGetter
- type DTypeSetter
- type DynamicShapeConfig
- type GGUFModel
- type GenerateOption
- func WithGrammar(g *grammar.Grammar) GenerateOption
- func WithMaxTokens(n int) GenerateOption
- func WithRepetitionPenalty(p float64) GenerateOption
- func WithStopStrings(ss ...string) GenerateOption
- func WithTemperature(t float64) GenerateOption
- func WithTopK(k int) GenerateOption
- func WithTopP(p float64) GenerateOption
- type Message
- type Model
- func (m *Model) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) (Response, error)
- func (m *Model) Close() error
- func (m *Model) Config() ModelMetadata
- func (m *Model) Embed(text string) ([]float32, error)
- func (m *Model) EmbeddingWeights() ([]float32, int)
- func (m *Model) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error)
- func (m *Model) GenerateBatch(ctx context.Context, prompts []string, opts ...GenerateOption) ([]string, error)
- func (m *Model) GenerateStream(ctx context.Context, prompt string, handler generate.TokenStream, ...) error
- func (m *Model) Generator() *generate.Generator[float32]
- func (m *Model) Info() *registry.ModelInfo
- func (m *Model) SetEmbeddingWeights(weights []float32, hiddenSize int)
- func (m *Model) SpeculativeGenerate(ctx context.Context, draft *Model, prompt string, draftLen int, ...) (string, error)
- func (m *Model) Tokenizer() tokenizer.Tokenizer
- type ModelMetadata
- type Option
- func WithBackend(backend string) Option
- func WithCacheDir(dir string) Option
- func WithDType(dtype string) Option
- func WithDevice(device string) Option
- func WithKVDtype(dtype string) Option
- func WithMaxSeqLen(n int) Option
- func WithMmap(enabled bool) Option
- func WithPrecision(precision string) Option
- func WithRegistry(r registry.ModelRegistry) Option
- type Response
- type RopeScalingConfig
- type ShapeRange
- type TRTInferenceEngine
- type UnsupportedOpError
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func ConvertGraphToTRT ¶
func ConvertGraphToTRT(g *graph.Graph[float32], workspaceBytes int, fp16 bool, dynamicShapes *DynamicShapeConfig) (*trtConversionResult, error)
ConvertGraphToTRT walks a graph in topological order and maps each node to a TensorRT layer. Returns serialized engine bytes or an UnsupportedOpError if the graph contains operations that cannot be converted. If dynamicShapes is non-nil, an optimization profile is created with the specified min/opt/max dimensions for each input.
func LoadTRTEngine ¶
func LoadTRTEngine(key string) ([]byte, error)
LoadTRTEngine reads a serialized TensorRT engine from the cache. Returns nil, nil on cache miss (file not found).
func RegisterAlias ¶
func RegisterAlias(shortName, repoID string)
RegisterAlias adds a custom short name -> HuggingFace repo ID mapping.
func ResolveAlias ¶
func ResolveAlias(name string) string
ResolveAlias returns the HuggingFace repo ID for a short alias. If the name is not an alias, it is returned unchanged.
func SaveTRTEngine ¶
func SaveTRTEngine(key string, data []byte) error
SaveTRTEngine writes a serialized TensorRT engine to the cache directory.
func TRTCacheKey ¶
func TRTCacheKey(modelID, precision string) (string, error)
TRTCacheKey builds a deterministic cache key from the model ID, precision, and GPU architecture. The key is a hex SHA-256 hash, which avoids filesystem issues with long or special-character model IDs.
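The hashing scheme can be sketched as follows; the exact field composition and separator used by TRTCacheKey are assumptions here, but the point is that any model ID collapses to a fixed-length, filesystem-safe hex string:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// cacheKey hashes the identifying fields into a fixed-length hex string,
// in the spirit of TRTCacheKey. The field composition and separator are
// illustrative assumptions.
func cacheKey(modelID, precision, gpuArch string) string {
	h := sha256.Sum256([]byte(modelID + "|" + precision + "|" + gpuArch))
	return hex.EncodeToString(h[:])
}

func main() {
	k := cacheKey("example-org/model-GGUF", "fp16", "sm_86")
	fmt.Println(len(k), k[:12]) // 64 hex chars, safe as a filename
}
```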
Types ¶
type ArchConfigRegistry ¶
type ArchConfigRegistry struct {
// contains filtered or unexported fields
}
ArchConfigRegistry maps model_type strings to config parsers.
func DefaultArchConfigRegistry ¶
func DefaultArchConfigRegistry() *ArchConfigRegistry
DefaultArchConfigRegistry returns a registry with all built-in parsers registered.
func (*ArchConfigRegistry) Parse ¶
func (r *ArchConfigRegistry) Parse(raw map[string]interface{}) (*ModelMetadata, error)
Parse dispatches to the registered parser for the model_type in raw, or falls back to generic field extraction for unknown types.
func (*ArchConfigRegistry) Register ¶
func (r *ArchConfigRegistry) Register(modelType string, parser ConfigParser)
Register adds a parser for the given model type.
type ConfigParser ¶
type ConfigParser func(raw map[string]interface{}) (*ModelMetadata, error)
ConfigParser parses a raw JSON map (from config.json) into ModelMetadata.
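A custom parser is a plain function over the decoded config.json map. The sketch below uses a trimmed-down local metadata struct to stay self-contained; a real parser would populate *ModelMetadata and be registered via ArchConfigRegistry.Register, and the JSON field names shown are illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// meta is a trimmed stand-in for ModelMetadata, just for this sketch.
type meta struct {
	Architecture string
	HiddenSize   int
	NumLayers    int
}

// parseMyArch shows the shape of a ConfigParser: pull typed fields out of
// the raw config.json map. JSON numbers decode as float64, hence the casts.
func parseMyArch(raw map[string]interface{}) (*meta, error) {
	m := &meta{Architecture: "myarch"}
	if v, ok := raw["hidden_size"].(float64); ok {
		m.HiddenSize = int(v)
	}
	if v, ok := raw["num_hidden_layers"].(float64); ok {
		m.NumLayers = int(v)
	}
	return m, nil
}

func main() {
	var raw map[string]interface{}
	_ = json.Unmarshal([]byte(`{"model_type":"myarch","hidden_size":2048,"num_hidden_layers":24}`), &raw)
	m, _ := parseMyArch(raw)
	fmt.Printf("%+v\n", m) // &{Architecture:myarch HiddenSize:2048 NumLayers:24}
}
```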
type ConstantValueGetter ¶
type ConstantValueGetter interface {
GetValue() *tensor.TensorNumeric[float32]
}
ConstantValueGetter is an interface for nodes that hold constant tensor data.
type DTypeSetter ¶
DTypeSetter is implemented by engines that support setting compute precision.
type DynamicShapeConfig ¶
type DynamicShapeConfig struct {
// InputShapes maps input index (0-based) to its shape range.
InputShapes []ShapeRange
}
DynamicShapeConfig specifies per-input shape ranges for TensorRT optimization profiles. When non-nil, the converter creates an optimization profile that allows variable-size inputs within the specified ranges.
type GGUFModel ¶
type GGUFModel struct {
Config *gguf.ModelConfig
Tensors map[string]*tensor.TensorNumeric[float32]
File *gguf.File
}
GGUFModel holds a loaded GGUF model's configuration and tensors. This is an intermediate representation; full inference requires an architecture-specific graph builder to convert these into a computation graph.
func LoadGGUF ¶
LoadGGUF loads a GGUF model file and returns its configuration and tensors. Tensor names are mapped from GGUF convention (blk.N.attn_q.weight) to Zerfoo canonical names (model.layers.N.self_attn.q_proj.weight).
func (*GGUFModel) ToModelMetadata ¶
func (m *GGUFModel) ToModelMetadata() *ModelMetadata
ToModelMetadata converts a GGUF model config to inference.ModelMetadata.
type GenerateOption ¶
type GenerateOption func(*generate.SamplingConfig)
GenerateOption configures a generation call.
func WithGrammar ¶
func WithGrammar(g *grammar.Grammar) GenerateOption
WithGrammar sets a grammar state machine for constrained decoding. When set, a token mask is applied at each sampling step to restrict output to tokens that are valid according to the grammar.
func WithMaxTokens ¶
func WithMaxTokens(n int) GenerateOption
WithMaxTokens sets the maximum number of tokens to generate.
func WithRepetitionPenalty ¶
func WithRepetitionPenalty(p float64) GenerateOption
WithRepetitionPenalty sets the repetition penalty factor.
func WithStopStrings ¶
func WithStopStrings(ss ...string) GenerateOption
WithStopStrings sets strings that stop generation.
func WithTemperature ¶
func WithTemperature(t float64) GenerateOption
WithTemperature sets the sampling temperature.
func WithTopK ¶
func WithTopK(k int) GenerateOption
WithTopK sets the top-K sampling cutoff.
func WithTopP ¶
func WithTopP(p float64) GenerateOption
WithTopP sets the top-P (nucleus) sampling parameter.
type Model ¶
type Model struct {
// contains filtered or unexported fields
}
Model is a loaded model ready for generation.
func NewTestModel ¶
func NewTestModel(
	gen *generate.Generator[float32],
	tok tokenizer.Tokenizer,
	eng compute.Engine[float32],
	meta ModelMetadata,
	info *registry.ModelInfo,
) *Model
NewTestModel constructs a Model from pre-built components. Intended for use in external test packages that need a Model without going through the full Load pipeline.
func (*Model) Chat ¶
func (m *Model) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) (Response, error)
Chat formats messages using the model's chat template and generates a response. Sessions are pooled to preserve CUDA graph replay.
func (*Model) Close ¶
func (m *Model) Close() error
Close releases resources held by the model. If the model was loaded on a GPU, this frees the CUDA engine's handles, pool, and stream. If loaded with mmap, this releases the memory mapping.
func (*Model) Embed ¶
func (m *Model) Embed(text string) ([]float32, error)
Embed returns an L2-normalized embedding vector for the given text by looking up token embeddings from the model's embedding table and mean-pooling them.
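The pooling step reduces to averaging the per-token embedding rows and scaling the result to unit length. A self-contained sketch of that arithmetic (not the package's implementation):

```go
package main

import (
	"fmt"
	"math"
)

// meanPoolNormalize averages the embedding rows for a token sequence and
// L2-normalizes the result, mirroring the pooling Embed performs.
func meanPoolNormalize(rows [][]float32) []float32 {
	dim := len(rows[0])
	out := make([]float32, dim)
	for _, r := range rows {
		for i, v := range r {
			out[i] += v
		}
	}
	var norm float64
	for i := range out {
		out[i] /= float32(len(rows)) // mean over tokens
		norm += float64(out[i]) * float64(out[i])
	}
	norm = math.Sqrt(norm)
	for i := range out {
		out[i] /= float32(norm) // scale to unit length
	}
	return out
}

func main() {
	// Two toy token embeddings of dimension 3.
	v := meanPoolNormalize([][]float32{{1, 0, 0}, {0, 1, 0}})
	fmt.Println(v) // mean is (0.5, 0.5, 0), then normalized to length 1
}
```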
func (*Model) EmbeddingWeights ¶
func (m *Model) EmbeddingWeights() ([]float32, int)
EmbeddingWeights returns the flattened token embedding table and the hidden dimension. Returns nil, 0 if embeddings are not available.
func (*Model) Generate ¶
func (m *Model) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error)
Generate produces text from a prompt. Sessions are pooled to reuse GPU memory addresses, enabling CUDA graph replay across calls. Concurrent Generate calls get separate sessions from the pool.
func (*Model) GenerateBatch ¶
func (m *Model) GenerateBatch(ctx context.Context, prompts []string, opts ...GenerateOption) ([]string, error)
GenerateBatch processes multiple prompts concurrently and returns the generated text for each prompt. Results are returned in the same order as the input prompts. If a prompt fails, its corresponding error is non-nil.
Note: batching is currently implemented with parallel goroutines rather than a shared PagedKV batched decode; true multi-sequence batching would require a deeper Generator refactor.
func (*Model) GenerateStream ¶
func (m *Model) GenerateStream(ctx context.Context, prompt string, handler generate.TokenStream, opts ...GenerateOption) error
GenerateStream delivers tokens one at a time via a callback. Sessions are pooled to preserve GPU memory addresses for CUDA graph replay.
func (*Model) SetEmbeddingWeights ¶
func (m *Model) SetEmbeddingWeights(weights []float32, hiddenSize int)
SetEmbeddingWeights sets the token embedding table used by Embed. weights is a flattened [vocabSize, hiddenSize] matrix.
func (*Model) SpeculativeGenerate ¶
func (m *Model) SpeculativeGenerate(
	ctx context.Context,
	draft *Model,
	prompt string,
	draftLen int,
	opts ...GenerateOption,
) (string, error)
SpeculativeGenerate runs speculative decoding using this model as the target and the draft model for token proposal. draftLen controls how many tokens are proposed per verification step.
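The propose/verify loop behind speculative decoding can be illustrated with toy token sources: the draft proposes draftLen tokens, the target checks them in order, and the longest agreeing prefix is accepted, with the target's token substituted at the first mismatch. This is a deterministic caricature of the scheme, not the package's sampler-aware acceptance rule:

```go
package main

import "fmt"

// speculativeStep accepts the longest prefix of the draft's proposals that
// the target agrees with; on the first disagreement the target's token is
// kept instead. Deterministic toy version of one verification step.
func speculativeStep(ctx []int, draft, target func([]int) int, draftLen int) []int {
	// Draft proposes draftLen tokens autoregressively.
	proposed := append([]int(nil), ctx...)
	for i := 0; i < draftLen; i++ {
		proposed = append(proposed, draft(proposed))
	}
	// Target verifies each proposal in order.
	accepted := append([]int(nil), ctx...)
	for i := len(ctx); i < len(proposed); i++ {
		want := target(accepted)
		accepted = append(accepted, want)
		if want != proposed[i] {
			break // mismatch: keep the target's token and stop
		}
	}
	return accepted
}

func main() {
	draft := func(ctx []int) int { return len(ctx) } // proposes 3, 4, 5, ...
	target := func(ctx []int) int {                  // agrees until position 5
		if len(ctx) >= 5 {
			return 99
		}
		return len(ctx)
	}
	// Draft proposes 3,4,5,6; target accepts 3,4 then corrects 5 -> 99.
	fmt.Println(speculativeStep([]int{0, 1, 2}, draft, target, 4)) // [0 1 2 3 4 99]
}
```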
type ModelMetadata ¶
type ModelMetadata struct {
Architecture string `json:"architecture"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"`
NumLayers int `json:"num_layers"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
EOSTokenID int `json:"eos_token_id"`
BOSTokenID int `json:"bos_token_id"`
ChatTemplate string `json:"chat_template"`
// Extended fields for multi-architecture support.
IntermediateSize int `json:"intermediate_size"`
NumQueryHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
RopeTheta float64 `json:"rope_theta"`
RopeScaling *RopeScalingConfig `json:"rope_scaling,omitempty"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
SlidingWindow int `json:"sliding_window"`
AttentionBias bool `json:"attention_bias"`
PartialRotaryFactor float64 `json:"partial_rotary_factor"`
// DeepSeek MLA and MoE fields.
KVLoRADim int `json:"kv_lora_rank"`
QLoRADim int `json:"q_lora_rank"`
QKRopeHeadDim int `json:"qk_rope_head_dim"`
NumExperts int `json:"num_experts"`
NumExpertsPerToken int `json:"num_experts_per_tok"`
}
ModelMetadata holds model configuration loaded from config.json.
type Option ¶
type Option func(*loadOptions)
Option configures model loading.
func WithBackend ¶
func WithBackend(backend string) Option
WithBackend selects the inference backend. Supported values: "" or "default" for the standard Engine path, and "tensorrt" for TensorRT-optimized inference. TensorRT requires the cuda build tag and a CUDA device.
func WithCacheDir ¶
func WithCacheDir(dir string) Option
WithCacheDir sets the model cache directory.
func WithDType ¶
func WithDType(dtype string) Option
WithDType sets the compute precision for the GPU engine. Supported values: "" or "fp32" for full precision, "fp16" for FP16 compute. FP16 mode converts activations F32->FP16 before GPU kernels and back after. Has no effect on CPU engines.
func WithDevice ¶
func WithDevice(device string) Option
WithDevice sets the compute device ("cpu", "cuda", "cuda:N", "rocm", or "opencl").
func WithKVDtype ¶
func WithKVDtype(dtype string) Option
WithKVDtype sets the KV cache storage dtype. Supported: "fp32" (default) and "fp16". FP16 halves KV cache bandwidth by storing keys and values in half precision.
func WithMaxSeqLen ¶
func WithMaxSeqLen(n int) Option
WithMaxSeqLen overrides the model's default maximum sequence length.
func WithMmap ¶
func WithMmap(enabled bool) Option
WithMmap enables memory-mapped model loading. When true, the ZMF file is mapped into memory using syscall.Mmap instead of os.ReadFile, avoiding heap allocation for model weights. Only supported on Unix platforms.
func WithPrecision ¶
func WithPrecision(precision string) Option
WithPrecision sets the compute precision for the TensorRT backend. Supported values: "" or "fp32" for full precision, "fp16" for half precision. Has no effect when the backend is not "tensorrt".
func WithRegistry ¶
func WithRegistry(r registry.ModelRegistry) Option
WithRegistry provides a custom model registry.
type RopeScalingConfig ¶
type RopeScalingConfig struct {
Type string `json:"type"`
Factor float64 `json:"factor"`
OriginalMaxPositionEmbeddings int `json:"original_max_position_embeddings"`
}
RopeScalingConfig holds configuration for RoPE scaling methods (e.g., YaRN).
type ShapeRange ¶
ShapeRange defines min/opt/max dimensions for a single input tensor. Used with DynamicShapeConfig to support variable-size inputs.
type TRTInferenceEngine ¶
type TRTInferenceEngine struct {
// contains filtered or unexported fields
}
TRTInferenceEngine holds a TensorRT engine and execution context for inference. It wraps the serialized engine, providing a Forward method that mirrors the graph forward pass but runs through TensorRT.
func (*TRTInferenceEngine) Close ¶
func (e *TRTInferenceEngine) Close() error
Close releases all TensorRT resources.
func (*TRTInferenceEngine) Forward ¶
func (e *TRTInferenceEngine) Forward(inputs []*tensor.TensorNumeric[float32], outputSize int) (*tensor.TensorNumeric[float32], error)
Forward runs inference through TensorRT with the given input tensors. Input tensors must already be on GPU.
type UnsupportedOpError ¶
type UnsupportedOpError struct {
Ops []string
}
UnsupportedOpError lists the operations that cannot be converted to TensorRT.
func (*UnsupportedOpError) Error ¶
func (e *UnsupportedOpError) Error() string