pjrt

package
v1.3.0 Latest
Published: Apr 3, 2026 License: Apache-2.0 Imports: 13 Imported by: 0

Documentation

Overview

Package pjrt provides purego bindings for the PJRT C API.

PJRT (Portable JAX Runtime) is OpenXLA's hardware plugin API. A single PJRT integration gives access to every accelerator that ships a PJRT plugin (CPU, CUDA, TPU, Trainium, Metal) without per-backend kernel work.

All bindings use dlopen/dlsym via the cuda package's exported helpers — zero CGo.

Constants

const DefaultCacheDir = ".cache/zerfoo/pjrt"

DefaultCacheDir is the default directory for cached PJRT executables.

const DefaultMaxCacheSize int64 = 2 << 30

DefaultMaxCacheSize is the default maximum cache size in bytes (2 << 30 = 2,147,483,648 bytes, i.e. 2 GiB).

Variables

This section is empty.

Functions

func Key

func Key(stablehloMLIR, platformName string) string

Key returns the content-addressed cache key for the given StableHLO program and platform name: SHA256(program + "\x00" + platform).
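The formula above can be sketched with the standard library. This is an illustration, not the package's code: the helper name cacheKey is hypothetical, and hex encoding of the digest is an assumption, since the documentation only specifies what is hashed.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// cacheKey is a hypothetical stand-in for Key. It hashes the program text,
// a NUL separator, and the platform name. Hex output is an assumption.
func cacheKey(stablehloMLIR, platformName string) string {
	h := sha256.New()
	h.Write([]byte(stablehloMLIR))
	h.Write([]byte{0}) // separator keeps ("ab", "c") distinct from ("a", "bc")
	h.Write([]byte(platformName))
	return hex.EncodeToString(h.Sum(nil))
}

func main() {
	program := "module @m {}"
	// Same program and platform always produce the same key.
	fmt.Println(cacheKey(program, "cpu") == cacheKey(program, "cpu"))
	// Changing only the platform produces a different key.
	fmt.Println(cacheKey(program, "cpu") == cacheKey(program, "cuda"))
}
```

Including the platform name in the hash means an executable compiled for one backend is never looked up for another, even when the StableHLO text is identical.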

func ToHostSlice

func ToHostSlice[T any](b *Buffer, dst []T) error

ToHostSlice is a typed convenience wrapper around ToHost that copies device buffer data into a pre-allocated Go slice of the appropriate type.

Types

type Buffer

type Buffer struct {
	// contains filtered or unexported fields
}

Buffer wraps a PJRT_Buffer handle and provides Go-friendly methods for device-to-host readback, metadata queries, and lifecycle management.

Buffers must be closed with Close() when no longer needed. Double-close is a safe no-op (finalizer safety).

func BufferFromHost

func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device, opts ...BufferOption) (*Buffer, error)

BufferFromHost transfers a Go slice to a PJRT device buffer.

The data slice is copied during the call (ImmutableOnlyDuringCall semantics by default). The shape describes the tensor dimensions. The target device determines where the buffer is placed.

Use WithDonation() to enable buffer donation for KV cache optimization.

func (*Buffer) Close

func (b *Buffer) Close() error

Close destroys the PJRT buffer handle and releases associated resources. Safe to call multiple times (double-close is a no-op for finalizer safety).

func (*Buffer) Delete

func (b *Buffer) Delete() error

Delete marks the buffer for deletion. The runtime may release the device memory immediately or defer it. After Delete, the buffer handle should not be used for data access, but Close is still required for handle cleanup.

func (*Buffer) Dtype

func (b *Buffer) Dtype() (ElementType, error)

Dtype returns the PJRT element type of this buffer.

func (*Buffer) Handle

func (b *Buffer) Handle() uintptr

Handle returns the raw PJRT_Buffer pointer.

func (*Buffer) OnDeviceSizeInBytes

func (b *Buffer) OnDeviceSizeInBytes() (int64, error)

OnDeviceSizeInBytes returns the buffer's memory footprint on the device.

func (*Buffer) ReadyEvent

func (b *Buffer) ReadyEvent() (uintptr, error)

ReadyEvent returns the PJRT_Event handle for this buffer's readiness. The caller is responsible for destroying the event via awaitEvent or destroyEvent.

func (*Buffer) Shape

func (b *Buffer) Shape() ([]int, error)

Shape returns the dimensions of this buffer.

func (*Buffer) ToHost

func (b *Buffer) ToHost(dst []byte) error

ToHost copies device buffer data back to a pre-allocated Go slice.

The destination slice must have exactly the right number of elements (product of Shape dimensions). The call blocks until the readback completes (PJRT_Event_Await).

type BufferOption

type BufferOption func(*bufferConfig)

BufferOption configures BufferFromHost behavior.

func WithDonation

func WithDonation() BufferOption

WithDonation enables buffer donation semantics. The runtime is allowed to take ownership of the host memory, avoiding a copy. The caller must not access the source slice after calling BufferFromHost with this option.

func WithSemantics

func WithSemantics(s HostBufferSemantics) BufferOption

WithSemantics sets the host buffer semantics for the transfer.
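BufferOption follows Go's functional options pattern. The sketch below shows how such options are typically applied; the lowercase names and the fields of bufferConfig are hypothetical, since the real bufferConfig is unexported and its layout is not documented.

```go
package main

import "fmt"

// semantics and bufferConfig are hypothetical stand-ins for illustration.
type semantics int32

type bufferConfig struct {
	semantics semantics
	donate    bool
}

// bufferOption mutates a config, mirroring the shape of BufferOption.
type bufferOption func(*bufferConfig)

func withDonation() bufferOption {
	return func(c *bufferConfig) { c.donate = true }
}

func withSemantics(s semantics) bufferOption {
	return func(c *bufferConfig) { c.semantics = s }
}

// applyOptions starts from defaults and applies options in order,
// so a later option overrides an earlier one.
func applyOptions(opts ...bufferOption) bufferConfig {
	cfg := bufferConfig{semantics: 0} // 0 = copy during the call
	for _, opt := range opts {
		opt(&cfg)
	}
	return cfg
}

func main() {
	fmt.Printf("%+v\n", applyOptions())
	fmt.Printf("%+v\n", applyOptions(withSemantics(2), withDonation()))
}
```

The pattern keeps the BufferFromHost signature stable: new behaviors can be added as options without breaking existing callers.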

type Cache

type Cache struct {
	// contains filtered or unexported fields
}

Cache stores serialized PJRT executables keyed by a content hash of the StableHLO program text and platform name. It provides LRU eviction when the total size exceeds MaxSize.

func NewCache

func NewCache(opts ...CacheOption) *Cache

NewCache creates a new executable cache. The cache directory is created on first Put if it does not already exist.

func (*Cache) Clear

func (c *Cache) Clear() error

Clear removes all cached entries.

func (*Cache) Dir

func (c *Cache) Dir() string

Dir returns the cache directory path.

func (*Cache) Evict

func (c *Cache) Evict()

Evict removes the least-recently-used entries until total cache size is within MaxSize.

func (*Cache) Get

func (c *Cache) Get(key string) ([]byte, error)

Get looks up a cached serialized executable by key. If found, the raw bytes are returned (caller must DeserializeAndLoad). Returns nil, nil on cache miss.

func (*Cache) Put

func (c *Cache) Put(key string, data []byte) error

Put stores serialized executable bytes under the given key. If storing the new entry would exceed MaxSize, the least-recently-used entries are evicted first.
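The evict-before-store behavior can be illustrated with a small in-memory index. This is only a sketch of size-bounded LRU eviction: the type lruIndex and its methods are hypothetical, and the real Cache keeps entries on disk with bookkeeping that is not documented here.

```go
package main

import (
	"container/list"
	"fmt"
)

type entry struct {
	key  string
	size int64
}

// lruIndex is a hypothetical in-memory model of a size-bounded LRU cache.
type lruIndex struct {
	maxSize int64
	total   int64
	order   *list.List               // front = most recently used
	items   map[string]*list.Element // key -> node in order
}

func newLRUIndex(maxSize int64) *lruIndex {
	return &lruIndex{maxSize: maxSize, order: list.New(), items: map[string]*list.Element{}}
}

// touch marks key as most recently used, as a successful Get would.
func (l *lruIndex) touch(key string) bool {
	el, ok := l.items[key]
	if ok {
		l.order.MoveToFront(el)
	}
	return ok
}

// put records an entry, first evicting least-recently-used entries until
// the new entry fits. It returns the evicted keys in eviction order.
func (l *lruIndex) put(key string, size int64) []string {
	if el, ok := l.items[key]; ok { // replacing an existing entry
		l.total -= el.Value.(*entry).size
		l.order.Remove(el)
		delete(l.items, key)
	}
	var evicted []string
	for l.total+size > l.maxSize && l.order.Len() > 0 {
		back := l.order.Back()
		e := back.Value.(*entry)
		l.order.Remove(back)
		delete(l.items, e.key)
		l.total -= e.size
		evicted = append(evicted, e.key)
	}
	l.items[key] = l.order.PushFront(&entry{key: key, size: size})
	l.total += size
	return evicted
}

func main() {
	idx := newLRUIndex(10)
	idx.put("a", 4)
	idx.put("b", 4)
	idx.touch("a")               // "b" becomes the least recently used entry
	fmt.Println(idx.put("c", 4)) // "b" is evicted to make room
}
```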

func (*Cache) Stats

func (c *Cache) Stats() CacheStats

Stats returns current cache statistics.

type CacheOption

type CacheOption func(*Cache)

CacheOption configures a Cache.

func WithCacheDir

func WithCacheDir(dir string) CacheOption

WithCacheDir sets the cache directory. If empty, defaults to $ZERFOO_PJRT_CACHE or ~/.cache/zerfoo/pjrt/.

func WithMaxCacheSize

func WithMaxCacheSize(n int64) CacheOption

WithMaxCacheSize sets the maximum total size of cached files in bytes.

type CacheStats

type CacheStats struct {
	Hits   int64
	Misses int64
	Size   int64 // total bytes on disk
	Files  int   // number of cached entries
}

CacheStats holds cache hit/miss/size statistics.

type Client

type Client struct {
	// contains filtered or unexported fields
}

Client wraps a PJRT_Client handle and provides Go-friendly methods for querying the runtime: platform name/version, device enumeration.

func NewClient

func NewClient(lib *PJRTLib, opts ...ClientOption) (*Client, error)

NewClient creates a PJRT client using the given plugin library. The client must be closed with Close() when no longer needed.

func (*Client) AddressableDevices

func (c *Client) AddressableDevices() ([]*Device, error)

AddressableDevices returns devices that this client can directly interact with.

func (*Client) Close

func (c *Client) Close() error

Close destroys the PJRT client and releases associated resources. Safe to call multiple times.

func (*Client) Compile

func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error)

Compile compiles a StableHLO MLIR text program and returns a LoadedExecutable. The executable is ready for execution and its output metadata (number of outputs, element types, dimensions) is queried and cached immediately.

func (*Client) DeserializeAndLoad

func (c *Client) DeserializeAndLoad(data []byte) (*LoadedExecutable, error)

DeserializeAndLoad restores a previously serialized executable, returning a LoadedExecutable ready for execution. This skips the compilation step entirely, which can save significant time for large models.

The serialized data must have been produced by Serialize() on the same plugin and hardware platform.

func (*Client) Devices

func (c *Client) Devices() ([]*Device, error)

Devices returns all devices known to the client (including non-addressable).

func (*Client) Handle

func (c *Client) Handle() uintptr

Handle returns the raw PJRT_Client pointer for use by other PJRT wrappers.

func (*Client) Lib

func (c *Client) Lib() *PJRTLib

Lib returns the PJRTLib associated with this client.

func (*Client) PlatformName

func (c *Client) PlatformName() (string, error)

PlatformName returns the name of the platform (e.g. "cpu", "cuda", "tpu").

func (*Client) PlatformVersion

func (c *Client) PlatformVersion() (string, error)

PlatformVersion returns the version string of the platform.

type ClientOption

type ClientOption func(*clientConfig)

ClientOption configures client creation.

func WithCreateOptions

func WithCreateOptions(opts uintptr) ClientOption

WithCreateOptions sets plugin-specific creation options.

type Device

type Device struct {
	// contains filtered or unexported fields
}

Device wraps a PJRT_Device handle and provides methods for querying device properties: ID, kind, addressability, hardware ID.

Device handles are owned by the Client and must not be destroyed independently — they become invalid when the Client is closed.

func (*Device) Handle

func (d *Device) Handle() uintptr

Handle returns the raw PJRT_Device pointer.

func (*Device) ID

func (d *Device) ID() (int, error)

ID returns the unique device ID within the client.

func (*Device) IsAddressable

func (d *Device) IsAddressable() (bool, error)

IsAddressable returns true if this device can be directly accessed by the client (i.e. it is local to this process).

func (*Device) Kind

func (d *Device) Kind() (string, error)

Kind returns the device kind string (e.g. "cpu", "gpu", "tpu").

func (*Device) LocalHardwareId

func (d *Device) LocalHardwareId() (int, error)

LocalHardwareId returns the hardware-level device ID. This is useful for multi-device systems (e.g. GPU index on a multi-GPU node).

type ElementType

type ElementType int32

ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API.

const (
	ElementTypeInvalid ElementType = 0
	ElementTypePRED    ElementType = 1  // bool
	ElementTypeS8      ElementType = 2  // int8
	ElementTypeS16     ElementType = 3  // int16
	ElementTypeS32     ElementType = 4  // int32
	ElementTypeS64     ElementType = 5  // int64
	ElementTypeU8      ElementType = 6  // uint8
	ElementTypeU16     ElementType = 7  // uint16
	ElementTypeU32     ElementType = 8  // uint32
	ElementTypeU64     ElementType = 9  // uint64
	ElementTypeF16     ElementType = 10 // float16
	ElementTypeF32     ElementType = 11 // float32
	ElementTypeF64     ElementType = 12 // float64
	ElementTypeBF16    ElementType = 16 // bfloat16
	ElementTypeF8E4M3  ElementType = 20 // float8 E4M3FN
)

func GoTypeToElementType

func GoTypeToElementType[T any]() ElementType

GoTypeToElementType maps a Go type (via its size and kind) to the corresponding PJRT element type.
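A mapping of this kind can be sketched with the reflect package. The function elementTypeOf and the lowercase constants are hypothetical; the numeric values are copied from the ElementType constants above. How the real function treats platform-sized int and uint, or types with no PJRT equivalent, is not documented, so this sketch simply returns the invalid code for them.

```go
package main

import (
	"fmt"
	"reflect"
)

type elementType int32

// Values copied from the ElementType constants in this package's documentation.
const (
	etInvalid elementType = 0
	etPRED    elementType = 1
	etS8      elementType = 2
	etS16     elementType = 3
	etS32     elementType = 4
	etS64     elementType = 5
	etU8      elementType = 6
	etU16     elementType = 7
	etU32     elementType = 8
	etU64     elementType = 9
	etF32     elementType = 11
	etF64     elementType = 12
)

// elementTypeOf maps a fixed-width Go type to an element type code by kind.
// Using (*T)(nil) avoids a nil reflect.Type when T is an interface type.
func elementTypeOf[T any]() elementType {
	switch reflect.TypeOf((*T)(nil)).Elem().Kind() {
	case reflect.Bool:
		return etPRED
	case reflect.Int8:
		return etS8
	case reflect.Int16:
		return etS16
	case reflect.Int32:
		return etS32
	case reflect.Int64:
		return etS64
	case reflect.Uint8:
		return etU8
	case reflect.Uint16:
		return etU16
	case reflect.Uint32:
		return etU32
	case reflect.Uint64:
		return etU64
	case reflect.Float32:
		return etF32
	case reflect.Float64:
		return etF64
	default:
		return etInvalid // assumption: unsupported kinds map to Invalid
	}
}

func main() {
	fmt.Println(elementTypeOf[float32](), elementTypeOf[int64](), elementTypeOf[string]())
}
```

Go has no built-in float16 or bfloat16 type, so those element types cannot be reached from a kind-based mapping like this one.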

func (ElementType) ByteSize

func (t ElementType) ByteSize() int

ByteSize returns the size in bytes of a single element of this type.

func (ElementType) String

func (t ElementType) String() string

String returns the PJRT element type name.

type ExecOption

type ExecOption func(*execConfig)

ExecOption configures Execute behavior.

func WithDeviceOrdinal

func WithDeviceOrdinal(ordinal int) ExecOption

WithDeviceOrdinal selects which device to execute on.

func WithInputDonation

func WithInputDonation(donated []bool) ExecOption

WithInputDonation marks specific inputs for buffer donation. donated[i] == true means input i may be consumed by the runtime.

type HostBufferSemantics

type HostBufferSemantics int32

HostBufferSemantics controls how PJRT handles the host data pointer during BufferFromHostBuffer.

const (
	// HostBufferImmutableOnlyDuringCall means PJRT copies the data during
	// the call and the host buffer can be modified immediately after return.
	HostBufferImmutableOnlyDuringCall HostBufferSemantics = 0

	// HostBufferImmutableUntilTransferCompletes means the host buffer must
	// remain valid until the returned event completes. Avoids a copy on
	// some backends.
	HostBufferImmutableUntilTransferCompletes HostBufferSemantics = 1

	// HostBufferImmutableZeroCopy means PJRT uses the host memory directly
	// (zero-copy). The host buffer must remain valid for the buffer lifetime.
	HostBufferImmutableZeroCopy HostBufferSemantics = 2
)

type LoadedExecutable

type LoadedExecutable struct {
	// contains filtered or unexported fields
}

LoadedExecutable wraps a PJRT_LoadedExecutable handle returned by Client.Compile or Client.DeserializeAndLoad. It holds the compiled StableHLO program ready for execution on the target device.

func (*LoadedExecutable) Close

func (e *LoadedExecutable) Close() error

Close destroys the loaded executable and releases associated resources. Safe to call multiple times.

func (*LoadedExecutable) Execute

func (e *LoadedExecutable) Execute(inputs []*Buffer, opts ...ExecOption) ([]*Buffer, error)

Execute runs the compiled program with the given input buffers and returns the output buffers. The caller owns the returned buffers and must close them when done.

func (*LoadedExecutable) Handle

func (e *LoadedExecutable) Handle() uintptr

Handle returns the raw PJRT_LoadedExecutable pointer.

func (*LoadedExecutable) NumOutputs

func (e *LoadedExecutable) NumOutputs() int

NumOutputs returns the number of outputs the compiled program produces.

func (*LoadedExecutable) OutputDimensions

func (e *LoadedExecutable) OutputDimensions() [][]int64

OutputDimensions returns the dimension arrays for each output. Each entry is a copy of the output's shape.

func (*LoadedExecutable) OutputElementTypes

func (e *LoadedExecutable) OutputElementTypes() []int32

OutputElementTypes returns the PJRT element type codes for each output.

func (*LoadedExecutable) Serialize

func (e *LoadedExecutable) Serialize() ([]byte, error)

Serialize serializes the compiled executable to bytes. The serialized form can be cached to disk and later restored with Client.DeserializeAndLoad, skipping recompilation on subsequent runs with the same model and hardware.
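The cache-then-compile workflow described above can be outlined as follows. Everything here is a hypothetical model: the store and backend interfaces and the fake implementations stand in for Cache, Client, and LoadedExecutable so that the control flow can run without a PJRT plugin.

```go
package main

import (
	"errors"
	"fmt"
)

// store mirrors the shape of Cache.Get and Cache.Put: Get returns nil, nil on a miss.
type store interface {
	Get(key string) ([]byte, error)
	Put(key string, data []byte) error
}

// backend is a hypothetical abstraction: CompileAndSerialize stands in for
// Client.Compile followed by Serialize, and Load for Client.DeserializeAndLoad.
type backend interface {
	CompileAndSerialize(program string) ([]byte, error)
	Load(data []byte) error
}

// loadOrCompile reports whether the executable came from the cache.
func loadOrCompile(s store, b backend, key, program string) (bool, error) {
	data, err := s.Get(key)
	if err != nil {
		return false, err
	}
	if data != nil { // cache hit: skip compilation entirely
		return true, b.Load(data)
	}
	data, err = b.CompileAndSerialize(program) // cache miss: compile once
	if err != nil {
		return false, err
	}
	return false, s.Put(key, data)
}

// memStore and fakeBackend are in-memory fakes used only for this demonstration.
type memStore map[string][]byte

func (m memStore) Get(k string) ([]byte, error) { return m[k], nil }
func (m memStore) Put(k string, d []byte) error { m[k] = d; return nil }

type fakeBackend struct{ compiles int }

func (f *fakeBackend) CompileAndSerialize(p string) ([]byte, error) {
	f.compiles++
	return []byte("exe:" + p), nil
}

func (f *fakeBackend) Load(data []byte) error {
	if len(data) == 0 {
		return errors.New("empty executable")
	}
	return nil
}

func main() {
	s, b := memStore{}, &fakeBackend{}
	for i := 0; i < 2; i++ {
		hit, err := loadOrCompile(s, b, "key", "module @m {}")
		fmt.Println(hit, err, b.compiles)
	}
}
```

In real use the key would come from Key, so a change to either the program text or the platform forces a fresh compile.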

type PJRTLib

type PJRTLib struct {

	// Version reported by the plugin.
	VersionMajor int
	VersionMinor int
	// contains filtered or unexported fields
}

PJRTLib holds a dlopen handle for a PJRT plugin and the resolved function pointers extracted from the PJRT_Api struct returned by GetPjrtApi().

func Load

func Load(pluginName string) (*PJRTLib, error)

Load opens a PJRT plugin shared library and extracts all function pointers from the PJRT_Api struct. pluginName is the bare library filename (e.g. "pjrt_c_api_cpu_plugin.so").

Load searches $PJRT_PLUGIN_PATH, standard system directories, AWS Neuron paths, and Python site-packages.

Load returns an error if the plugin is not found or if its API version is incompatible.

func (*PJRTLib) Close

func (lib *PJRTLib) Close() error

Close marks the PJRTLib as closed. The dlopen handle is intentionally not released — PJRT plugins are expected to remain loaded for the process lifetime (same as CUDA/cuDNN). Safe to call multiple times.
