onnxruntime_go

package module
v0.16.2
Published: Aug 15, 2023 License: MIT Imports: 8 Imported by: 2

README

Cross-Platform onnxruntime Wrapper for Go

About

This library seeks to provide an interface for loading and executing neural networks from Go(lang) code, while remaining as simple to use as possible.

The onnxruntime library provides a way to load and execute ONNX-format neural networks, though it primarily exposes C and C++ APIs. Several Go(lang) wrappers for the onnxruntime library already exist, but as far as I can tell, none of them support Windows. This is because Microsoft's onnxruntime library assumes the user will be using the MSVC compiler on Windows systems, while CGo on Windows requires MinGW.

This wrapper works around the problem by loading the onnxruntime shared library manually at runtime, which removes any dependency on the onnxruntime source code beyond its header files. Naturally, this approach works equally well on non-Windows systems.

Additionally, this library uses Go's recent addition of generics to support multiple Tensor data types; see the NewTensor or NewEmptyTensor functions.

Requirements

To use this library, you'll need a version of Go with cgo support (Go 1.18 or later, since the library relies on generics). If you are not using an amd64 version of Windows or Linux, or if you want to provide your own library for some other reason, you simply need to provide the correct path to the shared library when initializing the wrapper. This is shown in the first few lines of the following example.

Example Usage

The following example illustrates how this library can be used to load and run an ONNX network taking a single input tensor and producing a single output tensor, both of which contain 32-bit floating-point values. Each of these functions returns an error value, which will be non-nil in the case of failure. For brevity, the example passes every error to a small check() helper that panics; real code should handle errors appropriately.

package main

import (
    "fmt"

    ort "github.com/yalue/onnxruntime_go"
)

// check panics on any error. Deferred cleanup still runs during a panic.
func check(err error) {
    if err != nil {
        panic(err)
    }
}

func main() {
    // This line may be optional; by default the library will try to load
    // "onnxruntime.dll" on Windows, and "onnxruntime.so" on any other system.
    ort.SetSharedLibraryPath("path/to/onnxruntime.so")

    check(ort.InitializeEnvironment())
    defer ort.DestroyEnvironment()

    // To make it easier to work with the C API, this library requires the user
    // to create all input and output tensors prior to creating the session.
    inputData := []float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
    inputShape := ort.NewShape(2, 5)
    inputTensor, err := ort.NewTensor(inputShape, inputData)
    check(err)
    defer inputTensor.Destroy()

    // This hypothetical network maps a 2x5 input -> 2x3x4 output.
    outputShape := ort.NewShape(2, 3, 4)
    outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
    check(err)
    defer outputTensor.Destroy()

    session, err := ort.NewSession[float32]("path/to/network.onnx",
        []string{"Input 1 Name"}, []string{"Output 1 Name"},
        []*ort.Tensor[float32]{inputTensor}, []*ort.Tensor[float32]{outputTensor})
    check(err)
    defer session.Destroy()

    // Calling Run() will run the network, reading the current contents of the
    // input tensors and modifying the contents of the output tensors. Simply
    // modify the input tensor's data (available via inputTensor.GetData())
    // before calling Run().
    check(session.Run())

    outputData := outputTensor.GetData()
    fmt.Println(outputData)
}

The full documentation can be found at pkg.go.dev.

Documentation

Overview

This library wraps the C "onnxruntime" library maintained at https://github.com/microsoft/onnxruntime. It seeks to provide as simple an interface as possible to load and run ONNX-format neural networks from Go code.

Index

Constants

This section is empty.

Variables

var NotInitializedError error = fmt.Errorf("InitializeRuntime() has either " +
	"not yet been called, or did not return successfully")

Functions

func CheckInputsOutputs added in v0.15.0

func CheckInputsOutputs(inputs []*TensorWithType, outputs []*TensorWithType, inputNames, outputNames []string) error

func ConvertNames added in v0.15.0

func ConvertNames(names []string) []*C.char

func ConvertTensors added in v0.15.0

func ConvertTensors(tensors []*TensorWithType) []*C.OrtValue

func DestroyEnvironment

func DestroyEnvironment() error

Call this function to clean up the internal onnxruntime environment when it is no longer needed.

func GetTensorElementDataType

func GetTensorElementDataType[T TensorData]() C.ONNXTensorElementDataType

Returns the ONNX enum value used to indicate TensorData type T.
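To illustrate the kind of mapping GetTensorElementDataType performs, here is a minimal pure-Go sketch. It is not the package's implementation. The names elementType and elementTypeOf are hypothetical. The numeric values are those of onnxruntime's ONNXTensorElementDataType enum, which follows ONNX's TensorProto data-type numbering. Because the sketch uses an `any` constraint and a type switch, it differs from the real function in two ways: it accepts non-tensor types (returning typeUndefined), and it does not recognize named types such as a custom ~float32 type.

```go
package main

import "fmt"

// elementType stands in for onnxruntime's ONNXTensorElementDataType C enum.
type elementType int

const (
	typeUndefined elementType = 0
	typeFloat     elementType = 1
	typeUint8     elementType = 2
	typeInt8      elementType = 3
	typeUint16    elementType = 4
	typeInt16     elementType = 5
	typeInt32     elementType = 6
	typeInt64     elementType = 7
	typeBool      elementType = 9
	typeDouble    elementType = 11
	typeUint32    elementType = 12
	typeUint64    elementType = 13
)

// elementTypeOf is a hypothetical stand-in for GetTensorElementDataType[T]().
// It switches on the dynamic type of T's zero value, so only the exact
// built-in types are recognized; anything else yields typeUndefined.
func elementTypeOf[T any]() elementType {
	var zero T
	switch any(zero).(type) {
	case float32:
		return typeFloat
	case float64:
		return typeDouble
	case int8:
		return typeInt8
	case uint8:
		return typeUint8
	case int16:
		return typeInt16
	case uint16:
		return typeUint16
	case int32:
		return typeInt32
	case uint32:
		return typeUint32
	case int64:
		return typeInt64
	case uint64:
		return typeUint64
	case bool:
		return typeBool
	}
	return typeUndefined
}

func main() {
	fmt.Println(elementTypeOf[float32](), elementTypeOf[int64](), elementTypeOf[bool]()) // 1 7 9
}
```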

func InitializeEnvironment

func InitializeEnvironment() error

Call this function to initialize the internal onnxruntime environment. If this doesn't return an error, the caller will be responsible for calling DestroyEnvironment to free the onnxruntime state when no longer needed.

func IsInitialized

func IsInitialized() bool

Returns false if the onnxruntime package is not initialized. Called internally by several functions, to avoid segfaulting if InitializeEnvironment hasn't been called yet.

func NewEmptyTensorWithType

func NewEmptyTensorWithType(tensorType string, s Shape) (interface{}, error)

func RandomSelect added in v0.12.1

func RandomSelect(probabilities []float64) int

func Reshape added in v0.14.0

func Reshape(data []float64, shape []int64) [][]float64

func SetSharedLibraryPath

func SetSharedLibraryPath(path string)

Use this function to set the path to the "onnxruntime.so" or "onnxruntime.dll" shared library. By default, the path is "onnxruntime.so" on non-Windows systems and "onnxruntime.dll" on Windows. Users wishing to load the library from a particular location must call this function before calling onnxruntime.InitializeEnvironment().

func Softmax added in v0.12.1

func Softmax(x []float64) []float64

Calculates the softmax of x.

func TakeTopP added in v0.12.1

func TakeTopP(logs []float64, topP float64) ([]float64, []int)

Types

type FloatData

type FloatData interface {
	~float32 | ~float64
}

type IntData

type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}

type Log added in v0.12.1

type Log struct {
	Value float64
	Index int
}

type RunV3GenOptions added in v0.12.3

type RunV3GenOptions struct {
	MaxTokens           int
	MaxNewTokens        int
	TopP                float64
	Temperature         float64
	EOSTokenID          int
	ReplacementIndexes  []int
	AttentionMaskIndex  int
	UseCacheBranchIndex int
	SeparateDecoder     bool
}

type SessionV3 added in v0.12.3

type SessionV3 struct {
	// contains filtered or unexported fields
}

func NewSessionV3 added in v0.12.3

func NewSessionV3(path string, opts ...string) (*SessionV3, error)

func (*SessionV3) Destroy added in v0.12.3

func (s *SessionV3) Destroy() error

func (*SessionV3) GetInputNames added in v0.12.8

func (s *SessionV3) GetInputNames() []string

func (*SessionV3) GetInputShapes added in v0.12.5

func (s *SessionV3) GetInputShapes() (shapeTypes []ShapeType)

func (*SessionV3) GetInputTypes added in v0.12.5

func (s *SessionV3) GetInputTypes() []string

func (*SessionV3) GetOutputNames added in v0.12.8

func (s *SessionV3) GetOutputNames() []string

func (*SessionV3) GetOutputShapes added in v0.12.5

func (s *SessionV3) GetOutputShapes() (shapeTypes []ShapeType)

func (*SessionV3) GetOutputTypes added in v0.12.5

func (s *SessionV3) GetOutputTypes() []string

func (*SessionV3) Run added in v0.12.3

func (s *SessionV3) Run(inputs []*TensorWithType) (outputs []*TensorWithType, err error)

func (*SessionV3) RunDecoder added in v0.15.0

func (s *SessionV3) RunDecoder(inputs []*TensorWithType, opt *RunV3GenOptions) (outTokenIds []int64, err error)

func (*SessionV3) RunGen added in v0.12.3

func (s *SessionV3) RunGen(inputs []*TensorWithType, opt *RunV3GenOptions) (outputs []*TensorWithType, err error)

func (*SessionV3) RunMergedDecoder added in v0.15.0

func (s *SessionV3) RunMergedDecoder(inputs []*TensorWithType, opt *RunV3GenOptions) (outTokenIds []int64, err error)

type Shape

type Shape []int64

The Shape type holds the shape of the tensors used by the network's inputs and outputs.

func NewShape

func NewShape(dimensions ...int64) Shape

Returns a Shape, with the given dimensions.

func (Shape) Clone

func (s Shape) Clone() Shape

Makes and returns a deep copy of the Shape.

func (Shape) FlattenedSize

func (s Shape) FlattenedSize() int64

Returns the total number of elements in a tensor with the given shape.

func (Shape) String

func (s Shape) String() string
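The Shape helpers above amount to a few lines each. The following sketch reimplements them on a local Shape type to show the documented semantics: FlattenedSize is the product of the dimensions, and Clone returns an independent copy. It is not the package's code. How the package treats an empty (scalar) shape is not documented; this sketch returns 1 for it.

```go
package main

import "fmt"

// Shape mirrors the package's type: one int64 per dimension.
type Shape []int64

// NewShape copies the given dimensions into a new Shape.
func NewShape(dimensions ...int64) Shape {
	return append(Shape(nil), dimensions...)
}

// Clone returns a deep copy; modifying it leaves the original untouched.
func (s Shape) Clone() Shape {
	return append(Shape(nil), s...)
}

// FlattenedSize multiplies the dimensions together. An empty shape yields 1
// in this sketch (the package's handling of that case is undocumented).
func (s Shape) FlattenedSize() int64 {
	size := int64(1)
	for _, d := range s {
		size *= d
	}
	return size
}

func main() {
	s := NewShape(2, 3, 4)
	fmt.Println(s.FlattenedSize()) // 24

	c := s.Clone()
	c[0] = 10
	fmt.Println(s[0], c[0]) // 2 10
}
```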

type ShapeType added in v0.12.5

type ShapeType struct {
	Shape         []int64
	SymbolicShape []string
	Type          string
}

type Tensor

type Tensor[T TensorData] struct {
	// contains filtered or unexported fields
}

func GetTensor added in v0.10.0

func GetTensor[T TensorData](value *C.OrtValue, elementCount int64, data []T, shape []int64) (*Tensor[T], error)

func NewEmptyTensor

func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error)

Creates a new empty tensor with the given shape. The shape provided to this function is copied, and is no longer needed after this function returns.

func NewTensor

func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error)

Creates a new tensor backed by an existing data slice. The shape provided to this function is copied, and is no longer needed after this function returns. If the data slice is longer than s.FlattenedSize(), then only the first portion of the data will be used.
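The data-handling rules of NewTensor and NewEmptyTensor can be modeled without cgo. In this rule, the shape is copied, the tensor is backed by the caller's slice, and any elements past s.FlattenedSize() are ignored. In the following sketch, tensor, newTensor, and newEmptyTensor are hypothetical stand-ins; the real Tensor also owns an onnxruntime value. Two choices are assumptions, not documented behavior: rejecting a slice that is too short, and zero-filling the empty tensor.

```go
package main

import (
	"errors"
	"fmt"
)

type Shape []int64

func (s Shape) FlattenedSize() int64 {
	size := int64(1)
	for _, d := range s {
		size *= d
	}
	return size
}

// tensor is a hypothetical pure-Go stand-in for Tensor[T].
type tensor[T any] struct {
	shape Shape
	data  []T
}

func newTensor[T any](s Shape, data []T) (*tensor[T], error) {
	n := s.FlattenedSize()
	if int64(len(data)) < n {
		// Assumption: a too-short slice is rejected.
		return nil, errors.New("data slice is too short for the shape")
	}
	return &tensor[T]{
		shape: append(Shape(nil), s...), // the shape is copied
		data:  data[:n],                 // shares the caller's slice; extra elements are ignored
	}, nil
}

// newEmptyTensor allocates zero-valued storage for the given shape.
func newEmptyTensor[T any](s Shape) *tensor[T] {
	return &tensor[T]{shape: append(Shape(nil), s...), data: make([]T, s.FlattenedSize())}
}

func main() {
	shape := Shape{2, 3}
	t, err := newTensor(shape, []float32{0, 1, 2, 3, 4, 5, 6})
	if err != nil {
		panic(err)
	}
	shape[0] = 99                     // the tensor kept its own copy of the shape
	fmt.Println(t.shape, len(t.data)) // [2 3] 6

	empty := newEmptyTensor[int32](Shape{2, 2})
	fmt.Println(empty.data) // [0 0 0 0]
}
```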

func (*Tensor[T]) Clone

func (t *Tensor[T]) Clone() (*Tensor[T], error)

Makes a deep copy of the tensor, including its ONNXRuntime value. The Tensor returned by this function must be destroyed when no longer needed.

func (*Tensor[_]) Destroy

func (t *Tensor[_]) Destroy() error

Cleans up and frees the memory associated with this tensor.

func (*Tensor[T]) GetData

func (t *Tensor[T]) GetData() []T

Returns the slice containing the tensor's underlying data. The contents of the slice can be read or written to get or set the tensor's contents.

func (*Tensor[_]) GetShape

func (t *Tensor[_]) GetShape() Shape

Returns the shape of the tensor. The returned shape is only a copy; modifying this does *not* change the shape of the underlying tensor. (Modifying the tensor's shape can only be accomplished by Destroying and recreating the tensor with the same data.)
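The three methods above differ in what they copy. GetData returns the live backing slice, GetShape returns a copy, and Clone duplicates everything. The following sketch models these semantics on a hypothetical pure-Go tensor type; the real Clone also duplicates the underlying onnxruntime value.

```go
package main

import "fmt"

type Shape []int64

// tensor is a hypothetical pure-Go stand-in for Tensor[T].
type tensor[T any] struct {
	shape Shape
	data  []T
}

// GetData returns the backing slice itself, so writes through it change the tensor.
func (t *tensor[T]) GetData() []T { return t.data }

// GetShape returns a copy; modifying it leaves the tensor's shape untouched.
func (t *tensor[T]) GetShape() Shape { return append(Shape(nil), t.shape...) }

// Clone copies both the shape and the data.
func (t *tensor[T]) Clone() *tensor[T] {
	return &tensor[T]{shape: t.GetShape(), data: append([]T(nil), t.data...)}
}

func main() {
	t := &tensor[float32]{shape: Shape{1, 3}, data: []float32{1, 2, 3}}
	c := t.Clone()

	t.GetData()[0] = 10 // visible in t, not in c
	sh := t.GetShape()
	sh[0] = 5 // not visible in t

	fmt.Println(t.data, t.shape, c.data) // [10 2 3] [1 3] [1 2 3]
}
```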

type TensorData

type TensorData interface {
	FloatData | IntData | ~bool
}

This is used as a type constraint for the generic Tensor type.
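Because each term uses ~, TensorData admits named types whose underlying type is one of the listed numeric types or bool. Plain int and uint are not in the set. The following sketch copies the three constraints verbatim from above and uses them in a hypothetical helper, countNonZero. It shows that a named float type satisfies the constraint and that every permitted type supports == and !=.

```go
package main

import "fmt"

// Copies of the package's constraints.
type FloatData interface {
	~float32 | ~float64
}

type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}

type TensorData interface {
	FloatData | IntData | ~bool
}

// Celsius satisfies FloatData through the ~float32 term.
type Celsius float32

// countNonZero is a hypothetical helper; comparing with T's zero value works
// because every type in TensorData's type set is comparable.
func countNonZero[T TensorData](data []T) int {
	var zero T
	n := 0
	for _, v := range data {
		if v != zero {
			n++
		}
	}
	return n
}

func main() {
	fmt.Println(countNonZero([]float32{0, 1.5, 0, 2}))  // 2
	fmt.Println(countNonZero([]bool{true, false, true})) // 2
	fmt.Println(countNonZero([]Celsius{0, 21.5}))        // 1
	// countNonZero([]int{1}) would not compile: int is not in TensorData.
}
```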

type TensorWithType

type TensorWithType struct {
	TensorType string
	Tensor     interface{}
}

func (*TensorWithType) Destroy

func (t *TensorWithType) Destroy() error

func (*TensorWithType) GetData

func (t *TensorWithType) GetData() interface{}

func (*TensorWithType) GetShape added in v0.10.1

func (t *TensorWithType) GetShape() []int64
