package torch
Version: v0.0.0-...-5e92bb2

This package is not in the latest version of its module.

Published: Feb 25, 2019 License: Apache-2.0 Imports: 7 Imported by: 0





LibTorch (PyTorch) bindings for Golang. The library is designed first and foremost for running inference against serialized models exported from the Python version of PyTorch. It can also be used to compile TorchScript applications directly from Go.


$ go get


go-torch requires the LibTorch shared library to be available. For more information refer to … There is also an example Dockerfile, which is used for executing the library's tests.

import (
Creating Tensors

Supported scalar types (torch DType and the matching Go type):

  • torch.Byte: uint8
  • torch.Char: int8
  • torch.Int: int32
  • torch.Long: int64
  • torch.Float: float32
  • torch.Double: float64
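As an illustration of the table above, the sketch below picks the DType name for a Go value by unwrapping slices and arrays down to their element type. The helper `dtypeFor` and the `scalarDTypes` map are hypothetical, written only for this example; go-torch's own conversion code may be structured differently, and it works with `torch.DType` constants rather than strings.

```go
package main

import (
	"fmt"
	"reflect"
)

// scalarDTypes mirrors the table above: Go element kind -> torch DType name.
// Hypothetical lookup for illustration, not go-torch's internal implementation.
var scalarDTypes = map[reflect.Kind]string{
	reflect.Uint8:   "torch.Byte",
	reflect.Int8:    "torch.Char",
	reflect.Int32:   "torch.Int",
	reflect.Int64:   "torch.Long",
	reflect.Float32: "torch.Float",
	reflect.Float64: "torch.Double",
}

// dtypeFor unwraps slices and arrays down to their element type and looks up
// the matching DType name. Element types outside the table (e.g. int, string)
// are rejected with an error.
func dtypeFor(value interface{}) (string, error) {
	t := reflect.TypeOf(value)
	if t == nil {
		return "", fmt.Errorf("nil value has no dtype")
	}
	for t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
		t = t.Elem()
	}
	name, ok := scalarDTypes[t.Kind()]
	if !ok {
		return "", fmt.Errorf("unsupported Go type %v", t)
	}
	return name, nil
}
```

For example, `dtypeFor([][]float32{{1, 2, 3}})` yields `"torch.Float"`, while `dtypeFor([]int{1})` returns an error because plain `int` has no entry in the table.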

// Example 2x3 matrix (any values work)
matrix := [][]float32{
	{1, 2, 3},
	{4, 5, 6},
}
tensor, _ := torch.NewTensor(matrix)
tensor.Shape() // [2, 3]
tensor.DType() // torch.Float
Using serialized PyTorch models

For instructions on exporting models from PyTorch, refer to the PyTorch documentation.

// Load model
module, _ := torch.LoadJITModule("")

// Create an input tensor
inputTensor, _ := torch.NewTensor([][]float32{
    []float32{1, 2, 3},
})

// Forward propagation
res, _ := module.Forward(inputTensor)

Using TorchScript

For the TorchScript language itself, refer to the TorchScript documentation.

Currently supported input and output types:

  • Tensor
  • Tuple (of Tensor and/or nested Tuples)
sumScript := `
def sum(a, b):
    return a + b
`

// Compile TorchScript
module, _ := torch.CompileTorchScript(sumScript)

// Create inputs
a, _ := torch.NewTensor([]float32{1})
b, _ := torch.NewTensor([]float32{2})

res, _ := module.RunMethod("sum", a, b)
fmt.Printf("[1] + [2] = %+v\n", res.(*torch.Tensor).Value())
// output: [1] + [2] = [3]
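`RunMethod` and `Forward` return `interface{}`, so a caller has to inspect the result before using it, as the `res.(*torch.Tensor)` assertion above does. For methods that may return tuples, a type switch that recurses into nested tuples is one way to handle the supported types. In this sketch `Tensor` and `Tuple` are hypothetical local stand-ins for `torch.Tensor` and `torch.Tuple` so that it compiles without LibTorch, and `countTensors` is an illustrative helper, not part of go-torch.

```go
package main

import "fmt"

// Tensor and Tuple are hypothetical stand-ins for torch.Tensor and torch.Tuple,
// defined locally so this sketch compiles without LibTorch.
type Tensor struct{ values []float32 }

type Tuple []interface{}

// countTensors walks a value such as a RunMethod result and counts the tensors
// it contains, descending into nested tuples. Anything other than *Tensor or
// Tuple is reported as unsupported.
func countTensors(v interface{}) (int, error) {
	switch x := v.(type) {
	case *Tensor:
		return 1, nil
	case Tuple:
		total := 0
		for i, elem := range x {
			n, err := countTensors(elem)
			if err != nil {
				return 0, fmt.Errorf("tuple element %d: %w", i, err)
			}
			total += n
		}
		return total, nil
	default:
		return 0, fmt.Errorf("unsupported type %T", v)
	}
}
```

With the real library, the same switch would use `*torch.Tensor` and `torch.Tuple` as its cases.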


Much of the functionality for converting Go types to PyTorch tensors is a shameless copy of what Google does in its Go TensorFlow bindings, so a big part of the credit definitely goes to The TensorFlow Authors.











func PrintTensors

func PrintTensors(inputs ...*Tensor)

PrintTensors prints the contents of the given tensors.


type DType

type DType C.Torch_DataType

DType is the scalar data type of a tensor.

const (
	// Byte byte tensor (Go type uint8)
	Byte DType = C.Torch_Byte
	// Char char tensor (Go type int8)
	Char DType = C.Torch_Char
	// Int int tensor (Go type int32)
	Int DType = C.Torch_Int
	// Long long tensor (Go type int64)
	Long DType = C.Torch_Long
	// Float float tensor (Go type float32)
	Float DType = C.Torch_Float
	// Double double tensor (Go type float64)
	Double DType = C.Torch_Double
)

type Error

type Error struct {
	// contains filtered or unexported fields
}

Error is the error type returned by torch functions.

func (*Error) Error

func (te *Error) Error() string

type JITModule

type JITModule struct {
	// contains filtered or unexported fields
}

JITModule is a JIT-compiled PyTorch module.

func CompileTorchScript

func CompileTorchScript(torchScript string) (*JITModule, error)

CompileTorchScript compiles TorchScript and returns a *JITModule

module, _ := torch.CompileTorchScript(`
		def sum(a, b):
			return a + b
	`)

a, _ := torch.NewTensor([]float32{1})
b, _ := torch.NewTensor([]float32{2})

result, _ := module.RunMethod("sum", a, b)
fmt.Printf("[1] + [2] = %+v\n", result.(*torch.Tensor).Value())

Output:

[1] + [2] = [3]

func LoadJITModule

func LoadJITModule(path string) (*JITModule, error)

LoadJITModule loads a serialized module from the file at path.

func (*JITModule) Forward

func (m *JITModule) Forward(inputs ...interface{}) (interface{}, error)

Forward executes the forward method of the module (forward propagation).

func (*JITModule) GetMethod

func (m *JITModule) GetMethod(method string) (*JITModuleMethod, error)

GetMethod returns the named method of a JITModule.

func (*JITModule) GetMethodNames

func (m *JITModule) GetMethodNames() []string

GetMethodNames returns the names of all methods in the module.

func (*JITModule) RunMethod

func (m *JITModule) RunMethod(method string, inputs ...interface{}) (interface{}, error)

RunMethod executes the given method with tensors or tuples as input.

func (*JITModule) Save

func (m *JITModule) Save(path string) error

Save saves the module to the given path.

type JITModuleMethod

type JITModuleMethod struct {
	Module *JITModule
	Name   string
	// contains filtered or unexported fields
}

JITModuleMethod is a single method of a JITModule.

func (*JITModuleMethod) Arguments

func (m *JITModuleMethod) Arguments() []JITModuleMethodArgument

Arguments returns the method's arguments from its schema.

func (*JITModuleMethod) Returns

Returns returns the method's return type information from its schema.

func (*JITModuleMethod) Run

func (m *JITModuleMethod) Run(inputs ...interface{}) (interface{}, error)

Run executes the method with tensors as input.

type JITModuleMethodArgument

type JITModuleMethodArgument struct {
	Name string
	Type string
}

JITModuleMethodArgument contains information about a single method argument.
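The `Name` and `Type` fields are enough to print a readable signature for a method, for example when exploring a loaded model with `GetMethodNames` and `GetMethod`. The sketch below redefines `JITModuleMethodArgument` locally so it compiles on its own; `formatSignature` and its output format are hypothetical, and the `Type` strings used (such as `"Tensor"`) are assumptions about what the method schema reports.

```go
package main

import "strings"

// JITModuleMethodArgument mirrors the exported fields documented above;
// it is redefined here so the sketch compiles without LibTorch.
type JITModuleMethodArgument struct {
	Name string
	Type string
}

// formatSignature renders a method name and its schema arguments as
// "name(arg: Type, ...)". The format is an illustrative choice.
func formatSignature(method string, args []JITModuleMethodArgument) string {
	parts := make([]string, len(args))
	for i, a := range args {
		parts[i] = a.Name + ": " + a.Type
	}
	return method + "(" + strings.Join(parts, ", ") + ")"
}
```

Against a real module this might be called as `formatSignature(method.Name, method.Arguments())`, where `method` comes from `module.GetMethod("sum")`.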

type Tensor

type Tensor struct {
	// contains filtered or unexported fields
}

Tensor holds a multi-dimensional array of elements of a single data type.

func NewTensor

func NewTensor(value interface{}) (*Tensor, error)

NewTensor converts from a Go value to a Tensor. Valid values are scalars, slices, and arrays. Every element of a slice must have the same length so that the resulting Tensor has a valid shape.

func NewTensorWithShape

func NewTensorWithShape(value interface{}, shape []int64, dt DType) (*Tensor, error)

NewTensorWithShape converts a one-dimensional Go array or slice into a Tensor with the given shape and data type.
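For the reshape to make sense, the flat input presumably needs exactly as many elements as the product of the dimensions in `shape`. Whether go-torch checks this itself is not stated here, so the hypothetical helper below validates it before the library is called; `checkFlatShape` is not part of go-torch.

```go
package main

import "fmt"

// checkFlatShape verifies that n flat elements exactly fill the requested
// shape (the product of its dimensions). Hypothetical helper, not a go-torch API.
func checkFlatShape(n int, shape []int64) error {
	want := int64(1) // an empty shape describes a single scalar
	for _, d := range shape {
		if d < 0 {
			return fmt.Errorf("negative dimension %d in shape %v", d, shape)
		}
		want *= d
	}
	if int64(n) != want {
		return fmt.Errorf("%d elements cannot fill shape %v (needs %d)", n, shape, want)
	}
	return nil
}
```

For instance, a six-element `[]float32` passes `checkFlatShape(6, []int64{2, 3})` and could then be handed to `torch.NewTensorWithShape(data, []int64{2, 3}, torch.Float)`.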

func (*Tensor) DType

func (t *Tensor) DType() DType

DType returns the tensor's data type.

func (*Tensor) Shape

func (t *Tensor) Shape() []int64

Shape returns the tensor's shape.

func (*Tensor) Value

func (t *Tensor) Value() interface{}

Value returns the tensor's value as a Go type.

type Tuple

type Tuple []interface{}

Tuple is a tuple of values.

func NewTuple

func NewTuple(vals ...interface{}) (Tuple, error)

NewTuple returns a new tuple for the given values (Go types, torch.Tensor, torch.Tuple).

func (Tuple) Get

func (t Tuple) Get(index int) interface{}

Get returns the value at the given tuple index, or nil if there is no value at that index.
