Documentation
Overview ¶
Package pjrt provides purego bindings for the PJRT C API.
PJRT (Portable JAX Runtime) is OpenXLA's hardware plugin API. A single PJRT integration gives access to every accelerator that ships a PJRT plugin (CPU, CUDA, TPU, Trainium, Metal) without per-backend kernel work.
All bindings resolve symbols with dlopen/dlsym through the cuda package's exported helpers, so the package uses no CGo.
Index ¶
- Constants
- func Key(stablehloMLIR, platformName string) string
- func ToHostSlice[T any](b *Buffer, dst []T) error
- type Buffer
- func (b *Buffer) Close() error
- func (b *Buffer) Delete() error
- func (b *Buffer) Dtype() (ElementType, error)
- func (b *Buffer) Handle() uintptr
- func (b *Buffer) OnDeviceSizeInBytes() (int64, error)
- func (b *Buffer) ReadyEvent() (uintptr, error)
- func (b *Buffer) Shape() ([]int, error)
- func (b *Buffer) ToHost(dst []byte) error
- type BufferOption
- type Cache
- type CacheOption
- type CacheStats
- type Client
- func (c *Client) AddressableDevices() ([]*Device, error)
- func (c *Client) Close() error
- func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error)
- func (c *Client) DeserializeAndLoad(data []byte) (*LoadedExecutable, error)
- func (c *Client) Devices() ([]*Device, error)
- func (c *Client) Handle() uintptr
- func (c *Client) Lib() *PJRTLib
- func (c *Client) PlatformName() (string, error)
- func (c *Client) PlatformVersion() (string, error)
- type ClientOption
- type Device
- type ElementType
- type ExecOption
- type HostBufferSemantics
- type LoadedExecutable
- func (e *LoadedExecutable) Close() error
- func (e *LoadedExecutable) Execute(inputs []*Buffer, opts ...ExecOption) ([]*Buffer, error)
- func (e *LoadedExecutable) Handle() uintptr
- func (e *LoadedExecutable) NumOutputs() int
- func (e *LoadedExecutable) OutputDimensions() [][]int64
- func (e *LoadedExecutable) OutputElementTypes() []int32
- func (e *LoadedExecutable) Serialize() ([]byte, error)
- type PJRTLib
Constants ¶
const DefaultCacheDir = ".cache/zerfoo/pjrt"
DefaultCacheDir is the default directory for cached PJRT executables, relative to the user's home directory.
const DefaultMaxCacheSize int64 = 2 << 30
DefaultMaxCacheSize is the default maximum cache size in bytes (2 << 30 = 2 GiB).
Variables ¶
This section is empty.
Functions ¶
func Key ¶
func Key(stablehloMLIR, platformName string) string
Key returns the content-addressed cache key for the given StableHLO program and platform name: SHA256(program + "\x00" + platform).
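A minimal sketch of the documented key formula follows. The helper name cacheKey and the hex encoding of the digest are assumptions made for illustration; the real Key may encode the hash differently. The NUL separator is what keeps ("ab", "c") and ("a", "bc") from producing the same key.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
)

// cacheKey is a hypothetical re-implementation of
// SHA256(program + "\x00" + platform). Hex encoding is an assumption.
func cacheKey(stablehloMLIR, platformName string) string {
	// The NUL byte separates the two fields so different splits of the
	// same characters cannot collide.
	sum := sha256.Sum256([]byte(stablehloMLIR + "\x00" + platformName))
	return hex.EncodeToString(sum[:])
}
```

Because the platform name is part of the key, the same program compiled for "cpu" and "cuda" is cached under two different entries.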
func ToHostSlice ¶
func ToHostSlice[T any](b *Buffer, dst []T) error
ToHostSlice is a typed convenience wrapper around ToHost that copies device buffer data into a pre-allocated Go slice of the appropriate type.
Types ¶
type Buffer ¶
type Buffer struct {
// contains filtered or unexported fields
}
Buffer wraps a PJRT_Buffer handle and provides Go-friendly methods for device-to-host readback, metadata queries, and lifecycle management.
Buffers must be closed with Close() when no longer needed. Calling Close more than once is a safe no-op, which also makes it safe to call from a finalizer.
func BufferFromHost ¶
func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device, opts ...BufferOption) (*Buffer, error)
BufferFromHost transfers a Go slice to a PJRT device buffer.
The data slice is copied during the call (ImmutableOnlyDuringCall semantics by default). The shape describes the tensor dimensions. The target device determines where the buffer is placed.
Use WithDonation() to enable buffer donation for KV cache optimization.
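Before the copy, the shape and the slice length have to agree. The sketch below shows that kind of check. checkShape is a hypothetical helper, and whether BufferFromHost performs exactly this validation is an assumption.

```go
package main

import "fmt"

// checkShape reports whether shape describes exactly numElems elements.
// A rank-0 (scalar) shape, i.e. an empty slice, holds one element.
func checkShape(numElems int, shape []int) error {
	want := 1
	for i, d := range shape {
		if d < 0 {
			return fmt.Errorf("dimension %d is negative: %d", i, d)
		}
		want *= d
	}
	if want != numElems {
		return fmt.Errorf("shape %v needs %d elements, got %d", shape, want, numElems)
	}
	return nil
}
```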
func (*Buffer) Close ¶
func (b *Buffer) Close() error
Close destroys the PJRT buffer handle and releases associated resources. Safe to call multiple times (double-close is a no-op for finalizer safety).
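One common way to get an idempotent, finalizer-safe Close is to zero the handle under a lock and clear the finalizer on the first call. This is a sketch of that pattern, not the package's code: the handle type and its destroy function are hypothetical stand-ins for Buffer and the PJRT destroy call.

```go
package main

import (
	"runtime"
	"sync"
)

// handle is a hypothetical wrapper owning a raw PJRT-style handle.
type handle struct {
	mu      sync.Mutex
	raw     uintptr
	destroy func(uintptr) error
}

func newHandle(raw uintptr, destroy func(uintptr) error) *handle {
	h := &handle{raw: raw, destroy: destroy}
	// Releases the handle if the caller never calls Close.
	runtime.SetFinalizer(h, func(h *handle) { _ = h.Close() })
	return h
}

// Close destroys the handle once; later calls return nil and do nothing.
func (h *handle) Close() error {
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.raw == 0 {
		return nil // already closed
	}
	raw := h.raw
	h.raw = 0
	runtime.SetFinalizer(h, nil) // explicit Close makes the finalizer unnecessary
	return h.destroy(raw)
}
```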
func (*Buffer) Delete ¶
func (b *Buffer) Delete() error
Delete marks the buffer for deletion. The runtime may release the device memory immediately or defer it. After Delete, the buffer must not be used for data access, but Close is still required to destroy the handle itself.
func (*Buffer) Dtype ¶
func (b *Buffer) Dtype() (ElementType, error)
Dtype returns the PJRT element type of this buffer.
func (*Buffer) OnDeviceSizeInBytes ¶
func (b *Buffer) OnDeviceSizeInBytes() (int64, error)
OnDeviceSizeInBytes returns the buffer's memory footprint on the device.
func (*Buffer) ReadyEvent ¶
func (b *Buffer) ReadyEvent() (uintptr, error)
ReadyEvent returns the PJRT_Event handle for this buffer's readiness. The caller is responsible for destroying the event via awaitEvent or destroyEvent.
type BufferOption ¶
type BufferOption func(*bufferConfig)
BufferOption configures BufferFromHost behavior.
func WithDonation ¶
func WithDonation() BufferOption
WithDonation enables buffer donation semantics. The runtime is allowed to take ownership of the host memory, avoiding a copy. The caller must not access the source slice after calling BufferFromHost with this option.
func WithSemantics ¶
func WithSemantics(s HostBufferSemantics) BufferOption
WithSemantics sets the host buffer semantics for the transfer.
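BufferOption follows Go's functional-options pattern. The sketch below assumes a config holding the semantics plus a donation flag, starting from the documented default (ImmutableOnlyDuringCall). The field names are hypothetical because the real bufferConfig is unexported, and how the package maps donation onto HostBufferSemantics is not documented here.

```go
package main

// hostBufferSemantics mirrors the documented HostBufferSemantics values.
type hostBufferSemantics int32

const (
	immutableOnlyDuringCall         hostBufferSemantics = 0
	immutableUntilTransferCompletes hostBufferSemantics = 1
	immutableZeroCopy               hostBufferSemantics = 2
)

// bufferConfig's fields are hypothetical; the real struct is unexported.
type bufferConfig struct {
	semantics hostBufferSemantics
	donate    bool
}

type bufferOption func(*bufferConfig)

func withSemantics(s hostBufferSemantics) bufferOption {
	return func(c *bufferConfig) { c.semantics = s }
}

func withDonation() bufferOption {
	return func(c *bufferConfig) { c.donate = true }
}

// resolveBufferConfig applies options in order, so later options win.
func resolveBufferConfig(opts ...bufferOption) bufferConfig {
	cfg := bufferConfig{semantics: immutableOnlyDuringCall}
	for _, opt := range opts {
		opt(&cfg)
	}
	return cfg
}
```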
type Cache ¶
type Cache struct {
// contains filtered or unexported fields
}
Cache stores serialized PJRT executables keyed by a content hash of the StableHLO program text and platform name (see Key). It provides LRU eviction when the total size exceeds the configured maximum (set with WithMaxCacheSize; DefaultMaxCacheSize by default).
func NewCache ¶
func NewCache(opts ...CacheOption) *Cache
NewCache creates a new executable cache. The cache directory is created on first Put if it does not already exist.
func (*Cache) Evict ¶
func (c *Cache) Evict()
Evict removes least-recently-used entries until the total cache size is at or below the configured maximum.
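Size-bounded LRU eviction sorts entries by last access and drops the oldest until the total fits. The sketch below computes which keys to remove and does not touch the file system. The cacheEntry record and the use of a numeric access timestamp are assumptions; the real cache may track recency differently (for example, through file modification times).

```go
package main

import "sort"

// cacheEntry is a hypothetical record for one cached executable.
type cacheEntry struct {
	key        string
	size       int64
	lastAccess int64 // e.g. Unix nanoseconds of the most recent use
}

// evictLRU returns keys to delete, oldest first, so that the remaining
// total size is at most maxSize. The input slice is not modified.
func evictLRU(entries []cacheEntry, maxSize int64) []string {
	var total int64
	for _, e := range entries {
		total += e.size
	}
	sorted := append([]cacheEntry(nil), entries...)
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].lastAccess < sorted[j].lastAccess
	})
	var victims []string
	for _, e := range sorted {
		if total <= maxSize {
			break
		}
		victims = append(victims, e.key)
		total -= e.size
	}
	return victims
}
```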
func (*Cache) Get ¶
Get looks up a cached serialized executable by key. If found, the raw bytes are returned (caller must DeserializeAndLoad). Returns nil, nil on cache miss.
type CacheOption ¶
type CacheOption func(*Cache)
CacheOption configures a Cache.
func WithCacheDir ¶
func WithCacheDir(dir string) CacheOption
WithCacheDir sets the cache directory. If dir is empty, the cache uses $ZERFOO_PJRT_CACHE when it is set and falls back to ~/.cache/zerfoo/pjrt/ otherwise.
func WithMaxCacheSize ¶
func WithMaxCacheSize(n int64) CacheOption
WithMaxCacheSize sets the maximum total size of cached files in bytes.
type CacheStats ¶
type CacheStats struct {
Hits int64
Misses int64
Size int64 // total bytes on disk
Files int // number of cached entries
}
CacheStats holds cache hit/miss/size statistics.
type Client ¶
type Client struct {
// contains filtered or unexported fields
}
Client wraps a PJRT_Client handle and provides Go-friendly methods for querying the runtime (platform name and version, device enumeration) and for compiling and loading executables.
func NewClient ¶
func NewClient(lib *PJRTLib, opts ...ClientOption) (*Client, error)
NewClient creates a PJRT client using the given plugin library. The client must be closed with Close() when no longer needed.
func (*Client) AddressableDevices ¶
func (c *Client) AddressableDevices() ([]*Device, error)
AddressableDevices returns devices that this client can directly interact with.
func (*Client) Close ¶
func (c *Client) Close() error
Close destroys the PJRT client and releases associated resources. Safe to call multiple times.
func (*Client) Compile ¶
func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error)
Compile compiles a StableHLO MLIR text program and returns a LoadedExecutable. The executable is ready for execution and its output metadata (number of outputs, element types, dimensions) is queried and cached immediately.
func (*Client) DeserializeAndLoad ¶
func (c *Client) DeserializeAndLoad(data []byte) (*LoadedExecutable, error)
DeserializeAndLoad restores a previously serialized executable, returning a LoadedExecutable ready for execution. This skips the compilation step entirely, which can save significant time for large models.
The serialized data must have been produced by Serialize() on the same plugin and hardware platform.
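Key, Cache.Get, Serialize, and DeserializeAndLoad combine into a compile-or-load flow. The sketch below is one way to wire them, not the package's own helper. It uses plain function fields as hypothetical stand-ins for the Client, LoadedExecutable, and Cache methods (including a Put whose signature is assumed), and it reuses the hex-encoded key from the Key formula.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
)

// loader bundles hypothetical stand-ins for Client, LoadedExecutable and
// Cache operations so the flow runs without a PJRT plugin.
type loader[E any] struct {
	platform    string
	compile     func(mlir string) (E, error)
	serialize   func(exe E) ([]byte, error)
	deserialize func(data []byte) (E, error)
	get         func(key string) ([]byte, error) // nil, nil on a miss
	put         func(key string, data []byte) error
}

func (l loader[E]) loadOrCompile(mlir string) (E, error) {
	var zero E
	sum := sha256.Sum256([]byte(mlir + "\x00" + l.platform))
	key := hex.EncodeToString(sum[:])

	data, err := l.get(key)
	if err != nil {
		return zero, err
	}
	if data != nil {
		return l.deserialize(data) // hit: skip compilation entirely
	}
	exe, err := l.compile(mlir) // miss: compile, then persist for next time
	if err != nil {
		return zero, err
	}
	if blob, err := l.serialize(exe); err == nil {
		_ = l.put(key, blob) // a failed write only costs a later recompile
	}
	return exe, nil
}
```

Because the key includes the platform name, a blob cached for one plugin is never handed to DeserializeAndLoad on another, which matches the same-plugin requirement above.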
func (*Client) Devices ¶
func (c *Client) Devices() ([]*Device, error)
Devices returns all devices known to the client, including non-addressable ones.
func (*Client) PlatformName ¶
func (c *Client) PlatformName() (string, error)
PlatformName returns the name of the platform (e.g. "cpu", "cuda", "tpu").
func (*Client) PlatformVersion ¶
func (c *Client) PlatformVersion() (string, error)
PlatformVersion returns the version string of the platform.
type ClientOption ¶
type ClientOption func(*clientConfig)
ClientOption configures client creation.
func WithCreateOptions ¶
func WithCreateOptions(opts uintptr) ClientOption
WithCreateOptions sets plugin-specific creation options.
type Device ¶
type Device struct {
// contains filtered or unexported fields
}
Device wraps a PJRT_Device handle and provides methods for querying device properties: ID, kind, addressability, hardware ID.
Device handles are owned by the Client and must not be destroyed independently; they become invalid when the Client is closed.
func (*Device) IsAddressable ¶
IsAddressable returns true if this device can be directly accessed by the client (i.e. it is local to this process).
func (*Device) LocalHardwareId ¶
LocalHardwareId returns the hardware-level device ID. This is useful for multi-device systems (e.g. GPU index on a multi-GPU node).
type ElementType ¶
type ElementType int32
ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API.
const (
	ElementTypeInvalid ElementType = 0
	ElementTypePRED    ElementType = 1  // bool
	ElementTypeS8      ElementType = 2  // int8
	ElementTypeS16     ElementType = 3  // int16
	ElementTypeS32     ElementType = 4  // int32
	ElementTypeS64     ElementType = 5  // int64
	ElementTypeU8      ElementType = 6  // uint8
	ElementTypeU16     ElementType = 7  // uint16
	ElementTypeU32     ElementType = 8  // uint32
	ElementTypeU64     ElementType = 9  // uint64
	ElementTypeF16     ElementType = 10 // float16
	ElementTypeF32     ElementType = 11 // float32
	ElementTypeF64     ElementType = 12 // float64
	ElementTypeBF16    ElementType = 16 // bfloat16
	ElementTypeF8E4M3  ElementType = 20 // float8 E4M3FN
)
func GoTypeToElementType ¶
func GoTypeToElementType[T any]() ElementType
GoTypeToElementType maps a Go type (via its size and kind) to the corresponding PJRT element type.
func (ElementType) ByteSize ¶
func (t ElementType) ByteSize() int
ByteSize returns the size in bytes of a single element of this type.
func (ElementType) String ¶
func (t ElementType) String() string
String returns the PJRT element type name.
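The sketch below shows a kind-and-size mapping of the sort GoTypeToElementType describes, plus a ByteSize-style table. It uses lowercase copies of the constants above. Returning 0 and Invalid for unknown types is an assumption about the real methods. Go has no builtin 16-bit or 8-bit float, so F16, BF16, and F8E4M3 are never produced by this mapping.

```go
package main

import "reflect"

// elementType mirrors ElementType; values copied from the constants above.
type elementType int32

const (
	etInvalid elementType = 0
	etPRED    elementType = 1
	etS8      elementType = 2
	etS16     elementType = 3
	etS32     elementType = 4
	etS64     elementType = 5
	etU8      elementType = 6
	etU16     elementType = 7
	etU32     elementType = 8
	etU64     elementType = 9
	etF16     elementType = 10
	etF32     elementType = 11
	etF64     elementType = 12
	etBF16    elementType = 16
	etF8E4M3  elementType = 20
)

// byteSize returns the width of one element, or 0 for an unknown type.
func (t elementType) byteSize() int {
	switch t {
	case etPRED, etS8, etU8, etF8E4M3:
		return 1
	case etS16, etU16, etF16, etBF16:
		return 2
	case etS32, etU32, etF32:
		return 4
	case etS64, etU64, etF64:
		return 8
	}
	return 0
}

// fixedKinds covers Go kinds whose width does not depend on the platform.
var fixedKinds = map[reflect.Kind]elementType{
	reflect.Bool: etPRED,
	reflect.Int8: etS8, reflect.Int16: etS16, reflect.Int32: etS32, reflect.Int64: etS64,
	reflect.Uint8: etU8, reflect.Uint16: etU16, reflect.Uint32: etU32, reflect.Uint64: etU64,
	reflect.Float32: etF32, reflect.Float64: etF64,
}

// goTypeToElementType maps T by kind, using its size for int and uint.
func goTypeToElementType[T any]() elementType {
	t := reflect.TypeOf((*T)(nil)).Elem()
	if et, ok := fixedKinds[t.Kind()]; ok {
		return et
	}
	switch {
	case t.Kind() == reflect.Int && t.Size() == 8:
		return etS64
	case t.Kind() == reflect.Int:
		return etS32
	case t.Kind() == reflect.Uint && t.Size() == 8:
		return etU64
	case t.Kind() == reflect.Uint:
		return etU32
	}
	return etInvalid
}
```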
type ExecOption ¶
type ExecOption func(*execConfig)
ExecOption configures Execute behavior.
func WithDeviceOrdinal ¶
func WithDeviceOrdinal(ordinal int) ExecOption
WithDeviceOrdinal selects which device to execute on.
func WithInputDonation ¶
func WithInputDonation(donated []bool) ExecOption
WithInputDonation marks specific inputs for buffer donation. donated[i] == true means input i may be consumed by the runtime.
type HostBufferSemantics ¶
type HostBufferSemantics int32
HostBufferSemantics controls how PJRT handles the host data pointer during the PJRT_Client_BufferFromHostBuffer call that BufferFromHost makes.
const (
	// HostBufferImmutableOnlyDuringCall means PJRT copies the data during
	// the call and the host buffer can be modified immediately after return.
	HostBufferImmutableOnlyDuringCall HostBufferSemantics = 0
	// HostBufferImmutableUntilTransferCompletes means the host buffer must
	// remain valid until the returned event completes. Avoids a copy on
	// some backends.
	HostBufferImmutableUntilTransferCompletes HostBufferSemantics = 1
	// HostBufferImmutableZeroCopy means PJRT uses the host memory directly
	// (zero-copy). The host buffer must remain valid for the buffer lifetime.
	HostBufferImmutableZeroCopy HostBufferSemantics = 2
)
type LoadedExecutable ¶
type LoadedExecutable struct {
// contains filtered or unexported fields
}
LoadedExecutable wraps a PJRT_LoadedExecutable handle returned by Client.Compile or Client.DeserializeAndLoad. It holds a compiled StableHLO program ready for execution on the target device.
func (*LoadedExecutable) Close ¶
func (e *LoadedExecutable) Close() error
Close destroys the loaded executable and releases associated resources. Safe to call multiple times.
func (*LoadedExecutable) Execute ¶
func (e *LoadedExecutable) Execute(inputs []*Buffer, opts ...ExecOption) ([]*Buffer, error)
Execute runs the compiled program with the given input buffers and returns the output buffers. The caller owns the returned buffers and must close them when done.
func (*LoadedExecutable) Handle ¶
func (e *LoadedExecutable) Handle() uintptr
Handle returns the raw PJRT_LoadedExecutable pointer.
func (*LoadedExecutable) NumOutputs ¶
func (e *LoadedExecutable) NumOutputs() int
NumOutputs returns the number of outputs the compiled program produces.
func (*LoadedExecutable) OutputDimensions ¶
func (e *LoadedExecutable) OutputDimensions() [][]int64
OutputDimensions returns the dimension arrays for each output. Each entry is a copy of the output's shape.
func (*LoadedExecutable) OutputElementTypes ¶
func (e *LoadedExecutable) OutputElementTypes() []int32
OutputElementTypes returns the PJRT element type code of each output as a raw int32; convert a code with ElementType(code) to use methods such as ByteSize and String.
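Together, OutputDimensions and OutputElementTypes are enough to size host destination slices before readback. The sketch below computes per-output byte counts. outputByteSizes is a hypothetical helper, and the byteSize lookup is a parameter. With the real package, it could be `func(c int32) int { return pjrt.ElementType(c).ByteSize() }`.

```go
package main

import "fmt"

// outputByteSizes returns the host byte size of each output from its
// dimensions and element type code, using byteSize to look up widths.
func outputByteSizes(dims [][]int64, elemTypes []int32, byteSize func(int32) int) ([]int64, error) {
	if len(dims) != len(elemTypes) {
		return nil, fmt.Errorf("have %d shapes but %d element types", len(dims), len(elemTypes))
	}
	sizes := make([]int64, len(dims))
	for i, shape := range dims {
		w := byteSize(elemTypes[i])
		if w <= 0 {
			return nil, fmt.Errorf("output %d: unknown element type %d", i, elemTypes[i])
		}
		n := int64(1) // a scalar (empty shape) holds one element
		for _, d := range shape {
			n *= d
		}
		sizes[i] = n * int64(w)
	}
	return sizes, nil
}
```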
func (*LoadedExecutable) Serialize ¶
func (e *LoadedExecutable) Serialize() ([]byte, error)
Serialize serializes the compiled executable to bytes. The serialized form can be cached to disk and later restored with Client.DeserializeAndLoad, skipping recompilation on subsequent runs with the same model and hardware.
type PJRTLib ¶
type PJRTLib struct {
// Version reported by the plugin.
VersionMajor int
VersionMinor int
// contains filtered or unexported fields
}
PJRTLib holds a dlopen handle for a PJRT plugin and the resolved function pointers extracted from the PJRT_Api struct returned by GetPjrtApi().
func Load ¶
Load opens a PJRT plugin shared library and extracts all function pointers from the PJRT_Api struct. pluginName is the bare library filename (e.g. "pjrt_c_api_cpu_plugin.so").
Load searches $PJRT_PLUGIN_PATH, standard system directories, AWS Neuron paths, and Python site-packages.
Load returns a descriptive error if the plugin cannot be found or if its API version is incompatible.