Documentation ¶
Overview ¶
Package inference provides a high-level API for loading models and generating text with minimal boilerplate.
Index ¶
- func ConvertGraphToTRT(g *graph.Graph[float32], workspaceBytes int, fp16 bool, ...) (*trtConversionResult, error)
- func LoadTRTEngine(key string) ([]byte, error)
- func RegisterAlias(shortName, repoID string)
- func ResolveAlias(name string) string
- func SaveTRTEngine(key string, data []byte) error
- func TRTCacheKey(modelID, precision string) (string, error)
- type ArchConfigRegistry
- type ConfigParser
- type ConstantValueGetter
- type DTypeSetter
- type DynamicShapeConfig
- type GGUFModel
- type GenerateOption
- type Message
- type Model
- func (m *Model) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) (Response, error)
- func (m *Model) Close() error
- func (m *Model) Config() ModelMetadata
- func (m *Model) Embed(ctx context.Context, text string) ([]float32, error)
- func (m *Model) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error)
- func (m *Model) GenerateStream(ctx context.Context, prompt string, handler generate.TokenStream, ...) error
- func (m *Model) Generator() *generate.Generator[float32]
- func (m *Model) Info() *registry.ModelInfo
- func (m *Model) SpeculativeGenerate(ctx context.Context, draft *Model, prompt string, draftLen int, ...) (string, error)
- func (m *Model) Tokenizer() tokenizer.Tokenizer
- type ModelMetadata
- type Option
- func WithBackend(backend string) Option
- func WithCacheDir(dir string) Option
- func WithDType(dtype string) Option
- func WithDevice(device string) Option
- func WithKVDtype(dtype string) Option
- func WithMaxSeqLen(n int) Option
- func WithMmap(enabled bool) Option
- func WithPrecision(precision string) Option
- func WithRegistry(r registry.ModelRegistry) Option
- type Response
- type RopeScalingConfig
- type ShapeRange
- type TRTInferenceEngine
- type UnsupportedOpError
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func ConvertGraphToTRT ¶
func ConvertGraphToTRT(g *graph.Graph[float32], workspaceBytes int, fp16 bool, dynamicShapes *DynamicShapeConfig) (*trtConversionResult, error)
ConvertGraphToTRT walks a graph in topological order and maps each node to a TensorRT layer. Returns serialized engine bytes or an UnsupportedOpError if the graph contains operations that cannot be converted. If dynamicShapes is non-nil, an optimization profile is created with the specified min/opt/max dimensions for each input.
func LoadTRTEngine ¶
func LoadTRTEngine(key string) ([]byte, error)
LoadTRTEngine reads a serialized TensorRT engine from the cache. It returns nil, nil on a cache miss (file not found).
func RegisterAlias ¶
func RegisterAlias(shortName, repoID string)
RegisterAlias adds a custom short name -> HuggingFace repo ID mapping.
func ResolveAlias ¶
func ResolveAlias(name string) string
ResolveAlias returns the HuggingFace repo ID for a short alias. If the name is not an alias, it is returned unchanged.
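The alias mechanism above amounts to a map lookup with pass-through for unknown names. A self-contained sketch, with lower-case mirrors of the exported functions and purely hypothetical alias names (this is not the package's implementation or its built-in alias set):

```go
package main

import "fmt"

// aliases maps short model names to HuggingFace repo IDs.
// The entries registered below are illustrative only.
var aliases = map[string]string{}

// registerAlias adds a custom short name -> repo ID mapping,
// mirroring RegisterAlias.
func registerAlias(shortName, repoID string) {
	aliases[shortName] = repoID
}

// resolveAlias returns the repo ID for a known alias, or the input
// unchanged when no alias is registered, mirroring ResolveAlias.
func resolveAlias(name string) string {
	if repo, ok := aliases[name]; ok {
		return repo
	}
	return name
}

func main() {
	registerAlias("tiny", "example-org/tiny-model")
	fmt.Println(resolveAlias("tiny"))               // resolved to the repo ID
	fmt.Println(resolveAlias("example-org/other"))  // not an alias: unchanged
}
```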
func SaveTRTEngine ¶
func SaveTRTEngine(key string, data []byte) error
SaveTRTEngine writes a serialized TensorRT engine to the cache directory.
func TRTCacheKey ¶
func TRTCacheKey(modelID, precision string) (string, error)
TRTCacheKey builds a deterministic cache key from the model ID, precision, and GPU architecture. The key is a hex-encoded SHA-256 hash, which avoids filesystem issues with long or special-character model IDs.
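A deterministic hex SHA-256 key of the kind TRTCacheKey describes can be sketched as follows. The NUL field separator and the way the GPU architecture string is supplied are assumptions, not the package's exact scheme:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// cacheKey hashes the identifying fields into a fixed-length hex
// string, so arbitrarily long or slash-containing model IDs become
// safe filenames. The "\x00" separator between fields is an
// assumption; it prevents ("ab","c") colliding with ("a","bc").
func cacheKey(modelID, precision, gpuArch string) string {
	h := sha256.New()
	h.Write([]byte(modelID))
	h.Write([]byte{0})
	h.Write([]byte(precision))
	h.Write([]byte{0})
	h.Write([]byte(gpuArch))
	return hex.EncodeToString(h.Sum(nil))
}

func main() {
	k := cacheKey("meta-llama/Meta-Llama-3-8B", "fp16", "sm_89")
	fmt.Println(len(k)) // always 64 hex characters, stable across runs
}
```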
Types ¶
type ArchConfigRegistry ¶
type ArchConfigRegistry struct {
// contains filtered or unexported fields
}
ArchConfigRegistry maps model_type strings to config parsers.
func DefaultArchConfigRegistry ¶
func DefaultArchConfigRegistry() *ArchConfigRegistry
DefaultArchConfigRegistry returns a registry with all built-in parsers registered.
func (*ArchConfigRegistry) Parse ¶
func (r *ArchConfigRegistry) Parse(raw map[string]interface{}) (*ModelMetadata, error)
Parse dispatches to the registered parser for the model_type in raw, or falls back to generic field extraction for unknown types.
func (*ArchConfigRegistry) Register ¶
func (r *ArchConfigRegistry) Register(modelType string, parser ConfigParser)
Register adds a parser for the given model type.
type ConfigParser ¶
type ConfigParser func(raw map[string]interface{}) (*ModelMetadata, error)
ConfigParser parses a raw JSON map (from config.json) into ModelMetadata.
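The registry/parser relationship can be sketched with local stand-ins for the types. The metadata fields, the "llama" parser, and the shape of the generic fallback are simplified assumptions, not the package's actual behavior:

```go
package main

import "fmt"

// metadata is a cut-down stand-in for ModelMetadata.
type metadata struct {
	Architecture string
	HiddenSize   int
}

// parser mirrors ConfigParser: raw config.json map in, metadata out.
type parser func(raw map[string]interface{}) (*metadata, error)

// archRegistry dispatches on the "model_type" field, falling back to
// generic extraction for unknown architectures.
type archRegistry struct{ parsers map[string]parser }

func (r *archRegistry) register(modelType string, p parser) {
	r.parsers[modelType] = p
}

func (r *archRegistry) parse(raw map[string]interface{}) (*metadata, error) {
	mt, _ := raw["model_type"].(string)
	if p, ok := r.parsers[mt]; ok {
		return p(raw)
	}
	// Generic fallback: keep only what every config shares.
	return &metadata{Architecture: mt}, nil
}

func main() {
	r := &archRegistry{parsers: map[string]parser{}}
	r.register("llama", func(raw map[string]interface{}) (*metadata, error) {
		hs, _ := raw["hidden_size"].(float64) // JSON numbers decode as float64
		return &metadata{Architecture: "llama", HiddenSize: int(hs)}, nil
	})
	m, _ := r.parse(map[string]interface{}{"model_type": "llama", "hidden_size": 4096.0})
	fmt.Println(m.Architecture, m.HiddenSize)
}
```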
type ConstantValueGetter ¶
type ConstantValueGetter interface {
GetValue() *tensor.TensorNumeric[float32]
}
ConstantValueGetter is an interface for nodes that hold constant tensor data.
type DTypeSetter ¶
DTypeSetter is implemented by engines that support setting compute precision.
type DynamicShapeConfig ¶
type DynamicShapeConfig struct {
// InputShapes holds the shape range for each input, indexed by input position (0-based).
InputShapes []ShapeRange
}
DynamicShapeConfig specifies per-input shape ranges for TensorRT optimization profiles. When non-nil, the converter creates an optimization profile that allows variable-size inputs within the specified ranges.
type GGUFModel ¶
type GGUFModel struct {
Config *gguf.ModelConfig
Tensors map[string]*tensor.TensorNumeric[float32]
File *gguf.File
}
GGUFModel holds a loaded GGUF model's configuration and tensors. This is an intermediate representation; full inference requires an architecture-specific graph builder to convert these into a computation graph.
func LoadGGUF ¶
LoadGGUF loads a GGUF model file and returns its configuration and tensors. Tensor names are mapped from GGUF convention (blk.N.attn_q.weight) to Zerfoo canonical names (model.layers.N.self_attn.q_proj.weight).
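The name mapping described above, from blk.N.* to model.layers.N.*, can be sketched with a small suffix table. Only the attn_q pair is named in the doc; the attn_k entry is an illustrative guess at the same pattern:

```go
package main

import (
	"fmt"
	"strings"
)

// ggufSuffixes maps GGUF tensor-name suffixes to Zerfoo canonical
// suffixes. Only attn_q appears in the package doc; attn_k is an
// assumed entry following the same convention.
var ggufSuffixes = map[string]string{
	"attn_q.weight": "self_attn.q_proj.weight",
	"attn_k.weight": "self_attn.k_proj.weight",
}

// canonicalName rewrites "blk.N.<suffix>" to "model.layers.N.<mapped>",
// passing unrecognized names through unchanged.
func canonicalName(gguf string) string {
	rest, ok := strings.CutPrefix(gguf, "blk.")
	if !ok {
		return gguf
	}
	layer, suffix, ok := strings.Cut(rest, ".")
	if !ok {
		return gguf
	}
	mapped, ok := ggufSuffixes[suffix]
	if !ok {
		return gguf
	}
	return "model.layers." + layer + "." + mapped
}

func main() {
	fmt.Println(canonicalName("blk.0.attn_q.weight"))
	// -> model.layers.0.self_attn.q_proj.weight
}
```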
func (*GGUFModel) ToModelMetadata ¶
func (m *GGUFModel) ToModelMetadata() *ModelMetadata
ToModelMetadata converts a GGUF model config to inference.ModelMetadata.
type GenerateOption ¶
type GenerateOption func(*generate.SamplingConfig)
GenerateOption configures a generation call.
func WithMaxTokens ¶
func WithMaxTokens(n int) GenerateOption
WithMaxTokens sets the maximum number of tokens to generate.
func WithRepetitionPenalty ¶
func WithRepetitionPenalty(p float64) GenerateOption
WithRepetitionPenalty sets the repetition penalty factor.
func WithStopStrings ¶
func WithStopStrings(ss ...string) GenerateOption
WithStopStrings sets strings that stop generation.
func WithTemperature ¶
func WithTemperature(t float64) GenerateOption
WithTemperature sets the sampling temperature.
func WithTopP ¶
func WithTopP(p float64) GenerateOption
WithTopP sets the top-P (nucleus) sampling parameter.
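The GenerateOption helpers above follow Go's functional-options pattern: each helper returns a closure that mutates a config struct. A self-contained sketch with a simplified stand-in for generate.SamplingConfig (field names and default values here are assumptions):

```go
package main

import "fmt"

// samplingConfig is a cut-down stand-in for generate.SamplingConfig.
type samplingConfig struct {
	Temperature float64
	TopP        float64
	MaxTokens   int
}

type generateOption func(*samplingConfig)

func withTemperature(t float64) generateOption {
	return func(c *samplingConfig) { c.Temperature = t }
}

func withMaxTokens(n int) generateOption {
	return func(c *samplingConfig) { c.MaxTokens = n }
}

// apply starts from (assumed) defaults and folds each option over the
// config, which is how a variadic opts ...GenerateOption parameter is
// typically consumed.
func apply(opts ...generateOption) samplingConfig {
	cfg := samplingConfig{Temperature: 1.0, TopP: 1.0, MaxTokens: 128}
	for _, opt := range opts {
		opt(&cfg)
	}
	return cfg
}

func main() {
	cfg := apply(withTemperature(0.7), withMaxTokens(256))
	fmt.Println(cfg.Temperature, cfg.MaxTokens)
}
```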
type Model ¶
type Model struct {
// contains filtered or unexported fields
}
Model is a loaded model ready for generation.
func NewTestModel ¶
func NewTestModel(gen *generate.Generator[float32], tok tokenizer.Tokenizer, eng compute.Engine[float32], meta ModelMetadata, info *registry.ModelInfo) *Model
NewTestModel constructs a Model from pre-built components. Intended for use in external test packages that need a Model without going through the full Load pipeline.
func (*Model) Chat ¶
func (m *Model) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) (Response, error)
Chat formats messages using the model's chat template and generates a response.
func (*Model) Close ¶
func (m *Model) Close() error
Close releases resources held by the model. If the model was loaded on a GPU, this frees the CUDA engine's handles, pool, and stream. If loaded with mmap, this releases the memory mapping.
func (*Model) Embed ¶
func (m *Model) Embed(ctx context.Context, text string) ([]float32, error)
Embed returns a float32 embedding vector for the given text. It runs the model forward and mean-pools the last layer's hidden states.
func (*Model) Generate ¶
func (m *Model) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error)
Generate produces text from a prompt.
func (*Model) GenerateStream ¶
func (m *Model) GenerateStream(ctx context.Context, prompt string, handler generate.TokenStream, opts ...GenerateOption) error
GenerateStream delivers tokens one at a time via a callback.
func (*Model) SpeculativeGenerate ¶
func (m *Model) SpeculativeGenerate(ctx context.Context, draft *Model, prompt string, draftLen int, opts ...GenerateOption) (string, error)
SpeculativeGenerate runs speculative decoding using this model as the target and the draft model for token proposal. draftLen controls how many tokens are proposed per verification step.
type ModelMetadata ¶
type ModelMetadata struct {
Architecture string `json:"architecture"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"`
NumLayers int `json:"num_layers"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
EOSTokenID int `json:"eos_token_id"`
BOSTokenID int `json:"bos_token_id"`
ChatTemplate string `json:"chat_template"`
// Extended fields for multi-architecture support.
IntermediateSize int `json:"intermediate_size"`
NumQueryHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
RopeTheta float64 `json:"rope_theta"`
RopeScaling *RopeScalingConfig `json:"rope_scaling,omitempty"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
SlidingWindow int `json:"sliding_window"`
AttentionBias bool `json:"attention_bias"`
PartialRotaryFactor float64 `json:"partial_rotary_factor"`
// DeepSeek MLA and MoE fields.
KVLoRADim int `json:"kv_lora_rank"`
QLoRADim int `json:"q_lora_rank"`
QKRopeHeadDim int `json:"qk_rope_head_dim"`
NumExperts int `json:"num_experts"`
NumExpertsPerToken int `json:"num_experts_per_tok"`
}
ModelMetadata holds model configuration loaded from config.json.
type Option ¶
type Option func(*loadOptions)
Option configures model loading.
func WithBackend ¶
func WithBackend(backend string) Option
WithBackend selects the inference backend. Supported values: "" or "default" for the standard Engine path, "tensorrt" for TensorRT-optimized inference. TensorRT requires the cuda build tag and a CUDA device.
func WithCacheDir ¶
func WithCacheDir(dir string) Option
WithCacheDir sets the model cache directory.
func WithDType ¶
func WithDType(dtype string) Option
WithDType sets the compute precision for the GPU engine. Supported values: "" or "fp32" for full precision, "fp16" for FP16 compute. FP16 mode converts activations from F32 to FP16 before GPU kernels run and back afterward. It has no effect on CPU engines.
func WithDevice ¶
func WithDevice(device string) Option
WithDevice sets the compute device ("cpu" or "cuda").
func WithKVDtype ¶
func WithKVDtype(dtype string) Option
WithKVDtype sets the KV cache storage dtype. Supported: "fp32" (default), "fp16". FP16 halves KV cache bandwidth by storing keys and values in half precision.
func WithMaxSeqLen ¶
func WithMaxSeqLen(n int) Option
WithMaxSeqLen overrides the model's default max sequence length.
func WithMmap ¶
func WithMmap(enabled bool) Option
WithMmap enables memory-mapped model loading. When true, the ZMF file is mapped into memory using syscall.Mmap instead of os.ReadFile, avoiding heap allocation for model weights. Only supported on Unix platforms.
func WithPrecision ¶
func WithPrecision(precision string) Option
WithPrecision sets the compute precision for the TensorRT backend. Supported values: "" or "fp32" for full precision, "fp16" for half precision. Has no effect when the backend is not "tensorrt".
func WithRegistry ¶
func WithRegistry(r registry.ModelRegistry) Option
WithRegistry provides a custom model registry.
type RopeScalingConfig ¶
type RopeScalingConfig struct {
Type string `json:"type"`
Factor float64 `json:"factor"`
OriginalMaxPositionEmbeddings int `json:"original_max_position_embeddings"`
}
RopeScalingConfig holds configuration for RoPE scaling methods (e.g., YaRN).
type ShapeRange ¶
ShapeRange defines min/opt/max dimensions for a single input tensor. Used with DynamicShapeConfig to support variable-size inputs.
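A min/opt/max range of the kind ShapeRange describes might look like the following. The field names and the validation rule are assumptions, since the struct body is not shown in this section; the invariant itself (min ≤ opt ≤ max per dimension) is how TensorRT optimization profiles work:

```go
package main

import "fmt"

// shapeRange is a hypothetical stand-in for ShapeRange: per-dimension
// minimum, optimal, and maximum sizes for one input tensor.
type shapeRange struct {
	Min, Opt, Max []int
}

// valid checks the profile invariant min <= opt <= max in every
// dimension, with all three shapes of the same rank.
func (s shapeRange) valid() bool {
	if len(s.Min) != len(s.Opt) || len(s.Opt) != len(s.Max) {
		return false
	}
	for i := range s.Min {
		if s.Min[i] > s.Opt[i] || s.Opt[i] > s.Max[i] {
			return false
		}
	}
	return true
}

func main() {
	// Variable batch (1..8) and sequence length (1..4096), e.g. for a
	// token-ID input of shape [batch, seq].
	r := shapeRange{Min: []int{1, 1}, Opt: []int{1, 512}, Max: []int{8, 4096}}
	fmt.Println(r.valid())
}
```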
type TRTInferenceEngine ¶
type TRTInferenceEngine struct {
// contains filtered or unexported fields
}
TRTInferenceEngine holds a TensorRT engine and execution context for inference. It wraps the serialized engine, providing a Forward method that mirrors the graph forward pass but runs through TensorRT.
func (*TRTInferenceEngine) Close ¶
func (e *TRTInferenceEngine) Close() error
Close releases all TensorRT resources.
func (*TRTInferenceEngine) Forward ¶
func (e *TRTInferenceEngine) Forward(inputs []*tensor.TensorNumeric[float32], outputSize int) (*tensor.TensorNumeric[float32], error)
Forward runs inference through TensorRT with the given input tensors. Input tensors must already be on GPU.
type UnsupportedOpError ¶
type UnsupportedOpError struct {
Ops []string
}
UnsupportedOpError lists the operations that cannot be converted to TensorRT.
func (*UnsupportedOpError) Error ¶
func (e *UnsupportedOpError) Error() string