onnxruntime_go

Version: v0.0.0-...-335f59c
Published: Aug 17, 2023 License: MIT Imports: 5 Imported by: 0

README

Cross-Platform onnxruntime Wrapper for Go

About

This library seeks to provide an interface for loading and executing neural networks from Go(lang) code, while remaining as simple to use as possible.

The onnxruntime library provides a way to load and execute ONNX-format neural networks, though the library primarily supports C and C++ APIs. Several Go(lang) wrappers for the onnxruntime library already exist, but as far as I can tell, none of them support Windows. This is because Microsoft's onnxruntime library assumes the user will be using the MSVC compiler on Windows systems, while CGo on Windows requires using Mingw.

This wrapper works around the issue by manually loading the onnxruntime shared library at runtime, removing any dependency on the onnxruntime source code beyond the header files. Naturally, this approach works equally well on non-Windows systems.

Additionally, this library uses Go's recent addition of generics to support multiple Tensor data types; see the NewTensor or NewEmptyTensor functions.

Requirements

To use this library, you'll need a version of Go with cgo support; because the library uses generics, that means Go 1.18 or later. If you are not using an amd64 version of Windows or Linux (or if you want to provide your own library for some other reason), you need to provide the correct path to the shared library by calling SetSharedLibraryPath before initializing the environment. This is seen in the first few lines of the following example.

Example Usage

The following example illustrates how this library can be used to load and run an ONNX network taking a single input tensor and producing a single output tensor, both of which contain 32-bit floating point values. Each of the functions returns an err value, which will be non-nil in the case of failure; to keep the example short, every error is passed to a small check helper that simply exits.

package main

import (
    "fmt"
    "os"

    ort "github.com/yalue/onnxruntime_go"
)

// check aborts on any error. Note that os.Exit skips deferred cleanup, so
// real code should return or otherwise handle errors instead.
func check(err error) {
    if err != nil {
        fmt.Fprintln(os.Stderr, "error:", err)
        os.Exit(1)
    }
}

func main() {
    // This line may be optional; by default the library will try to load
    // "onnxruntime.dll" on Windows, and "onnxruntime.so" on any other system.
    ort.SetSharedLibraryPath("path/to/onnxruntime.so")

    check(ort.InitializeEnvironment())
    defer ort.DestroyEnvironment()

    // To make it easier to work with the C API, this library requires the user
    // to create all input and output tensors prior to creating the session.
    inputData := []float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
    inputShape := ort.NewShape(2, 5)
    inputTensor, err := ort.NewTensor(inputShape, inputData)
    check(err)
    defer inputTensor.Destroy()

    // This hypothetical network maps a 2x5 input -> 2x3x4 output.
    outputShape := ort.NewShape(2, 3, 4)
    outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
    check(err)
    defer outputTensor.Destroy()

    session, err := ort.NewSession[float32]("path/to/network.onnx",
        []string{"Input 1 Name"}, []string{"Output 1 Name"},
        []*ort.Tensor[float32]{inputTensor}, []*ort.Tensor[float32]{outputTensor})
    check(err)
    defer session.Destroy()

    // Calling Run() will run the network, reading the current contents of the
    // input tensors and modifying the contents of the output tensors. Simply
    // modify the input tensor's data (available via inputTensor.GetData())
    // before calling Run().
    check(session.Run())

    outputData := outputTensor.GetData()
    fmt.Println(outputData)

    // ...
}

The full documentation can be found at pkg.go.dev.

Documentation

Overview

This library wraps the C "onnxruntime" library maintained at https://github.com/microsoft/onnxruntime. It seeks to provide as simple an interface as possible to load and run ONNX-format neural networks from Go code.

Index

Constants

This section is empty.

Variables

var NotInitializedError error = fmt.Errorf("InitializeRuntime() has either " +
	"not yet been called, or did not return successfully")
var ShapeOverflowError error = fmt.Errorf("The shape's flattened size " +
	"overflows an int64")
var ZeroShapeLengthError error = fmt.Errorf("The shape has no dimensions")

Functions

func DestroyEnvironment

func DestroyEnvironment() error

Call this function to clean up the internal onnxruntime environment when it is no longer needed.

func DisableTelemetry

func DisableTelemetry() error

Disables telemetry events for the onnxruntime environment. Must be called after initializing the environment using InitializeEnvironment(). It is unclear from the onnxruntime docs whether this will cause an error or silently return if telemetry is already disabled.

func EnableTelemetry

func EnableTelemetry() error

Enables telemetry events for the onnxruntime environment. Must be called after initializing the environment using InitializeEnvironment(). It is unclear from the onnxruntime docs whether this will cause an error or silently return if telemetry is already enabled.

func GetTensorElementDataType

func GetTensorElementDataType[T TensorData]() C.ONNXTensorElementDataType

Returns the ONNX enum value used to indicate TensorData type T.
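For illustration, here is a pure-Go sketch of this mapping. The hypothetical elementDataType function returns the integer value of the corresponding ONNX tensor element type (FLOAT = 1, INT64 = 7, DOUBLE = 11, and so on) rather than the C enum, and its choice to switch on the reflect kind, so that named types also resolve, is an assumption rather than the library's documented behavior.

```go
package main

import (
	"fmt"
	"reflect"
)

// Constraints copied from the package documentation.
type FloatData interface {
	~float32 | ~float64
}

type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}

type TensorData interface {
	FloatData | IntData
}

// elementDataType is a hypothetical stand-in for GetTensorElementDataType.
// It returns the ONNX tensor element type value for T as a plain int.
// Switching on the underlying kind (an assumption of this sketch) means a
// named type such as `type Celsius float32` also maps to FLOAT.
func elementDataType[T TensorData]() int {
	var zero T
	switch reflect.TypeOf(zero).Kind() {
	case reflect.Float32:
		return 1 // FLOAT
	case reflect.Uint8:
		return 2 // UINT8
	case reflect.Int8:
		return 3 // INT8
	case reflect.Uint16:
		return 4 // UINT16
	case reflect.Int16:
		return 5 // INT16
	case reflect.Int32:
		return 6 // INT32
	case reflect.Int64:
		return 7 // INT64
	case reflect.Float64:
		return 11 // DOUBLE
	case reflect.Uint32:
		return 12 // UINT32
	case reflect.Uint64:
		return 13 // UINT64
	}
	return 0 // UNDEFINED; not reachable for TensorData types
}

func main() {
	fmt.Println(elementDataType[float32]()) // 1
	fmt.Println(elementDataType[int64]())   // 7
}
```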

func InitializeEnvironment

func InitializeEnvironment() error

Call this function to initialize the internal onnxruntime environment. If this doesn't return an error, the caller will be responsible for calling DestroyEnvironment to free the onnxruntime state when no longer needed.

func IsInitialized

func IsInitialized() bool

Returns false if the onnxruntime package is not initialized. Called internally by several functions, to avoid segfaulting if InitializeEnvironment hasn't been called yet.

func SetSharedLibraryPath

func SetSharedLibraryPath(path string)

Use this function to set the path to the "onnxruntime.so" or "onnxruntime.dll" library. By default, it will be set to "onnxruntime.so" on non-Windows systems, and "onnxruntime.dll" on Windows. Users wishing to specify a particular location of this library must call this function before calling onnxruntime.InitializeEnvironment().
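A minimal sketch of the documented default, using a hypothetical defaultLibraryPath helper keyed on runtime.GOOS; the package itself may select its default differently.

```go
package main

import (
	"fmt"
	"runtime"
)

// defaultLibraryPath is a hypothetical helper reproducing the documented
// default: "onnxruntime.dll" on Windows, "onnxruntime.so" everywhere else.
func defaultLibraryPath(goos string) string {
	if goos == "windows" {
		return "onnxruntime.dll"
	}
	return "onnxruntime.so"
}

func main() {
	path := defaultLibraryPath(runtime.GOOS)
	// With the real package, a different location would be passed to
	// ort.SetSharedLibraryPath before ort.InitializeEnvironment() is called.
	fmt.Println(path)
}
```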

Types

type BadShapeDimensionError

type BadShapeDimensionError struct {
	DimensionIndex int
	DimensionSize  int64
}

This error type is returned when validating a tensor shape that has a negative or zero dimension.

func (*BadShapeDimensionError) Error

func (e *BadShapeDimensionError) Error() string

type DynamicSession

type DynamicSession[In TensorData, Out TensorData] struct {
	// contains filtered or unexported fields
}

Similar to Session, but does not require specifying the input and output shapes at session creation time, and allows input and output tensors to have different types. This allows for fully dynamic input to the ONNX model.

func NewDynamicSession

func NewDynamicSession[in TensorData, out TensorData](onnxFilePath string, inputNames, outputNames []string) (*DynamicSession[in, out], error)

Same as NewSession, but for dynamic sessions.

func NewDynamicSessionWithONNXData

func NewDynamicSessionWithONNXData[in TensorData, out TensorData](onnxData []byte, inputNames, outputNames []string) (*DynamicSession[in, out], error)

Similar to NewSessionWithONNXData, but for dynamic sessions.

func (*DynamicSession[_, _]) Destroy

func (s *DynamicSession[_, _]) Destroy() error

func (*DynamicSession[in, out]) Run

func (s *DynamicSession[in, out]) Run(inputs []*Tensor[in], outputs []*Tensor[out]) error

Runs the dynamic session. Unlike Session.Run(), this method requires the caller to pass slices of input and output tensor pointers of the appropriate types on each call. The results are written to the output tensors, and the caller is responsible for destroying the tensors to free their memory.

type FloatData

type FloatData interface {
	~float32 | ~float64
}

type IntData

type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}

type Session

type Session[T TensorData] struct {
	// contains filtered or unexported fields
}

A wrapper around the OrtSession C struct. Requires the user to maintain all input and output tensors, and to use the same data type for input and output tensors.

func NewSession

func NewSession[T TensorData](onnxFilePath string, inputNames,
	outputNames []string, inputs, outputs []*Tensor[T]) (*Session[T], error)

Loads the ONNX network at the given path, and initializes a Session instance. If this returns successfully, the caller must call Destroy() on the returned session when it is no longer needed. We require the user to provide the input and output tensors and names at this point so that they do not need to be re-allocated every time Run() is called. Instead, the user can update the input tensors' data before calling Run() and read the output tensors' data afterward. The input and output tensors MUST outlive this session, and calling session.Destroy() will not destroy the input or output tensors.

func NewSessionWithONNXData

func NewSessionWithONNXData[T TensorData](onnxData []byte, inputNames,
	outputNames []string, inputs, outputs []*Tensor[T]) (*Session[T], error)

The same as NewSession, but takes a slice of bytes containing the .onnx network rather than a file path.

func (*Session[_]) Destroy

func (s *Session[_]) Destroy() error

func (*Session[T]) Run

func (s *Session[T]) Run() error

Runs the session, updating the contents of the output tensors on success.

type Shape

type Shape []int64

The Shape type holds the shape of the tensors used for the network's inputs and outputs.

func NewShape

func NewShape(dimensions ...int64) Shape

Returns a Shape, with the given dimensions.

func (Shape) Clone

func (s Shape) Clone() Shape

Makes and returns a deep copy of the Shape.

func (Shape) Equals

func (s Shape) Equals(other Shape) bool

Returns true if both shapes match in every dimension.
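A minimal pure-Go sketch of the Clone and Equals semantics described above, on a local Shape type. It assumes that shapes of different rank never compare equal; the real methods may be implemented differently.

```go
package main

import "fmt"

type Shape []int64

// Clone returns a deep copy backed by its own array, so later writes to
// the copy do not affect the original.
func (s Shape) Clone() Shape {
	c := make(Shape, len(s))
	copy(c, s)
	return c
}

// Equals reports whether both shapes have the same rank and the same size
// in every dimension.
func (s Shape) Equals(other Shape) bool {
	if len(s) != len(other) {
		return false
	}
	for i := range s {
		if s[i] != other[i] {
			return false
		}
	}
	return true
}

func main() {
	a := Shape{2, 3, 4}
	b := a.Clone()
	b[0] = 7                       // a is unchanged
	fmt.Println(a, b, a.Equals(b)) // [2 3 4] [7 3 4] false
}
```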

func (Shape) FlattenedSize

func (s Shape) FlattenedSize() int64

Returns the total number of elements in a tensor with the given shape. Note that this may be an invalid value due to overflow or negative dimensions. If a shape comes from an untrusted source, it may be a good practice to call Validate() prior to trusting the FlattenedSize.
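To make the overflow caveat concrete, here is a pure-Go sketch of an unchecked FlattenedSize on a local Shape type. It assumes the real method simply multiplies the dimensions, which the description suggests but does not state. Go's signed integer arithmetic wraps on overflow, so a huge shape can silently report a size of 0.

```go
package main

import "fmt"

type Shape []int64

// FlattenedSize multiplies all dimensions together without any checks,
// so negative dimensions or int64 overflow yield a meaningless result.
func (s Shape) FlattenedSize() int64 {
	size := int64(1)
	for _, d := range s {
		size *= d
	}
	return size
}

func main() {
	fmt.Println(Shape{2, 3, 4}.FlattenedSize()) // 24
	// 2^32 * 2^32 = 2^64, which wraps around to 0 in int64 arithmetic.
	fmt.Println(Shape{1 << 32, 1 << 32}.FlattenedSize()) // 0
	fmt.Println(Shape{-2, 3}.FlattenedSize())            // -6
}
```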

func (Shape) String

func (s Shape) String() string

func (Shape) Validate

func (s Shape) Validate() error

Returns a non-nil error if the shape has bad or zero dimensions. May return a ZeroShapeLengthError, a ShapeOverflowError, or a BadShapeDimensionError. In the future, this may return other types of errors if others become necessary.

type Tensor

type Tensor[T TensorData] struct {
	// contains filtered or unexported fields
}

func NewEmptyTensor

func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error)

Creates a new empty tensor with the given shape. The shape provided to this function is copied, and is no longer needed after this function returns.

func NewTensor

func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error)

Creates a new tensor backed by an existing data slice. The shape provided to this function is copied, and is no longer needed after this function returns. If the data slice is longer than s.FlattenedSize(), then only the first portion of the data will be used.
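The pure-Go sketch below models this behavior with a hypothetical tensor type: the shape is copied, and only the first FlattenedSize() elements of the data slice are kept, still sharing the caller's backing array. Rejecting a slice that is too short or a negative size is an assumption; the description above only covers the longer case.

```go
package main

import (
	"errors"
	"fmt"
)

type Shape []int64

// tensor is a hypothetical stand-in for Tensor[T]; the real type also owns
// an onnxruntime value.
type tensor[T any] struct {
	shape Shape
	data  []T
}

func flattenedSize(s Shape) int64 {
	n := int64(1)
	for _, d := range s {
		n *= d
	}
	return n
}

// newTensor copies the shape and keeps only the first flattenedSize(s)
// elements of data, without copying the data itself.
func newTensor[T any](s Shape, data []T) (*tensor[T], error) {
	n := flattenedSize(s)
	if n < 0 {
		return nil, errors.New("invalid shape")
	}
	if int64(len(data)) < n {
		return nil, errors.New("data slice is shorter than the shape requires")
	}
	shape := make(Shape, len(s))
	copy(shape, s)
	return &tensor[T]{shape: shape, data: data[:n]}, nil
}

func main() {
	s := Shape{2, 2}
	t, err := newTensor(s, []float32{1, 2, 3, 4, 5, 6})
	if err != nil {
		panic(err)
	}
	s[0] = 9                     // the tensor kept its own copy of the shape
	fmt.Println(t.shape, t.data) // [2 2] [1 2 3 4]
}
```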

func (*Tensor[T]) Clone

func (t *Tensor[T]) Clone() (*Tensor[T], error)

Makes a deep copy of the tensor, including its ONNXRuntime value. The Tensor returned by this function must be destroyed when no longer needed. The returned tensor will also no longer refer to the same underlying data; use GetData() to obtain the new underlying slice.
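The aliasing rules for GetData() and Clone() can be illustrated with a hypothetical pure-Go tensor type. This sketch ignores the onnxruntime value the real Tensor owns, and its Clone returns no error; the real Clone does return one, and its result must also be destroyed.

```go
package main

import "fmt"

// tensor is a hypothetical stand-in for Tensor[T].
type tensor[T any] struct {
	data []T
}

// GetData returns the live backing slice; writes through it change the tensor.
func (t *tensor[T]) GetData() []T { return t.data }

// Clone deep-copies the data, so the clone no longer shares storage.
func (t *tensor[T]) Clone() *tensor[T] {
	c := make([]T, len(t.data))
	copy(c, t.data)
	return &tensor[T]{data: c}
}

func main() {
	orig := &tensor[float32]{data: []float32{1, 2, 3}}
	clone := orig.Clone()
	orig.GetData()[0] = 10                       // visible through orig only
	fmt.Println(orig.GetData(), clone.GetData()) // [10 2 3] [1 2 3]
}
```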

func (*Tensor[_]) Destroy

func (t *Tensor[_]) Destroy() error

Cleans up and frees the memory associated with this tensor.

func (*Tensor[T]) GetData

func (t *Tensor[T]) GetData() []T

Returns the slice containing the tensor's underlying data. The contents of the slice can be read or written to get or set the tensor's contents.

func (*Tensor[_]) GetShape

func (t *Tensor[_]) GetShape() Shape

Returns the shape of the tensor. The returned shape is only a copy; modifying this does *not* change the shape of the underlying tensor. (Modifying the tensor's shape can only be accomplished by Destroying and recreating the tensor with the same data.)

func (*Tensor[T]) SetData

func (t *Tensor[T]) SetData(indata []T) error

Sets the tensor's underlying data to the contents of indata. Returns a non-nil error on failure.

type TensorData

type TensorData interface {
	FloatData | IntData
}

This is used as a type constraint for the generic Tensor type.
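As a sketch of what these constraints permit, the following self-contained program copies the three constraint declarations and uses them in a hypothetical generic sum function. Because each term uses ~, a named type whose underlying type is float32 also qualifies, while plain int does not, since it is not listed in IntData.

```go
package main

import "fmt"

// Constraints as declared in the package documentation.
type FloatData interface {
	~float32 | ~float64
}

type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}

type TensorData interface {
	FloatData | IntData
}

// Probability is a hypothetical named type; its underlying type is float32,
// so it satisfies FloatData through the ~ operator.
type Probability float32

// sum is a hypothetical generic helper accepting any TensorData slice.
// Fixed-size integer types wrap on overflow as usual.
func sum[T TensorData](values []T) T {
	var total T
	for _, v := range values {
		total += v
	}
	return total
}

func main() {
	fmt.Println(sum([]int64{1, 2, 3}))               // 6
	fmt.Println(sum([]Probability{0.25, 0.5, 0.25})) // 1
	// sum([]int{1, 2}) would not compile: int is not part of IntData.
}
```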
