treelite

package module
v0.1.6 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: May 4, 2023 License: Apache-2.0 Imports: 13 Imported by: 1

README

go-treelite: TreeLite binding in Go

Go Report Card CodeQL Go Go.Dev reference

This binding currently works for treelite 3.4.0.

Prerequisites

install treelite

install treelite (see .devcontainer/Dockerfile)

1. Install the packages needed to build treelite

The following packages are required:

  • g++
  • cmake
  • make
2. install treelite
git clone https://github.com/dmlc/treelite.git -b 3.4.0
cd treelite
mkdir build && cd build
cmake ..
sudo make install
export LIBRARY_PATH=/usr/local/lib${LIBRARY_PATH:+:$LIBRARY_PATH}
export LD_LIBRARY_PATH=/usr/local/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}

Usage

Add this library to your project.

go get github.com/getumen/go-treelite

Documentation

The documentation is hosted here. You can also take a look at example_test.go.

Documentation

Overview

Example
// End-to-end walkthrough: build a DMatrix, load an XGBoost model, annotate it,
// compile it to a shared library, then load that library and predict.
// loadData is defined elsewhere; it supplies the values passed below as
// data, nrow and ncol.
data, nRow, nCol := loadData()

// Create a dense DMatrix; NaN is passed as CreateFromMat's `missing` argument.
dMatrix, err := treelite.CreateFromMat(data, nRow, nCol, float32(math.NaN()))
if err != nil {
	log.Fatal(err)
}
defer dMatrix.Close()

// Load a model file generated by XGBoost.
model, err := treelite.LoadXGBoostModel("testdata/xgboost.model")
if err != nil {
	log.Fatal(err)
}
defer model.Close()

// Create an annotator for the model and the DMatrix (numThread=1, verbose=true).
annotator, err := treelite.NewAnnotator(model, dMatrix, 1, true)
if err != nil {
	log.Fatal(err)
}
defer annotator.Close()

// Save the annotation result; the same path is handed to the compiler below
// through CompilerParam.AnnotationPath.
err = annotator.Save("testdata/go-example-annotation.json")
if err != nil {
	log.Fatal(err)
}

// Create the "ast_native" compiler with quantization enabled, verbose output,
// and ParallelComp set to the number of CPUs.
compiler, err := treelite.NewCompiler(
	"ast_native",
	&treelite.CompilerParam{
		AnnotationPath: "testdata/go-example-annotation.json",
		Quantize:       true,
		ParallelComp:   runtime.NumCPU(),
		Verbose:        true,
	},
)
if err != nil {
	log.Fatal(err)
}
defer compiler.Close()

// Export the model as a shared library using the "gcc" toolchain and no extra
// compile options. The destination path is given without a file extension;
// the predictor below loads it with GetSharedLibExtension() appended.
err = compiler.ExportSharedLib(
	model,
	"testdata/go_example_compiled_model",
	"gcc",
	nil,
)
if err != nil {
	log.Fatal(err)
}

// Load the compiled library, passing the number of CPUs as NumWorkerThread.
predictor, err := treelite.NewPredictor(
	fmt.Sprintf("testdata/go_example_compiled_model.%s", treelite.GetSharedLibExtension()),
	runtime.NumCPU(),
)
if err != nil {
	log.Fatal(err)
}
defer predictor.Close()

// Predict on the whole batch with verbose=true and predMargin=false.
scores, err := predictor.PredictBatch(dMatrix, true, false)
if err != nil {
	log.Fatal(err)
}
fmt.Printf("%+v\n", scores)
Output:

Index

Examples

Constants

This section is empty.

Variables

This section is empty.

Functions

func GetSharedLibExtension added in v0.1.5

func GetSharedLibExtension() string

Types

type Annotator

type Annotator struct {
	// contains filtered or unexported fields
}

Annotator annotates branches in decision trees; see https://treelite.readthedocs.io/en/latest/tutorials/optimize.html#annotate-conditional-branches

func NewAnnotator

func NewAnnotator(
	model *Model,
	dmatrix *DMatrix,
	numThread int,
	verbose bool,
) (*Annotator, error)

NewAnnotator creates an annotator of the given model and DMatrix

func (Annotator) Close

func (a Annotator) Close() error

Close frees internally allocated memory

func (Annotator) Save

func (a Annotator) Save(path string) error

Save saves annotation result to the given path

type BooleanInC

type BooleanInC bool

BooleanInC generates int representation in json

func (*BooleanInC) MarshalJSON

func (b *BooleanInC) MarshalJSON() ([]byte, error)

MarshalJSON returns int representation of bool

type Compiler

type Compiler struct {
	// contains filtered or unexported fields
}

Compiler produces an optimized prediction subroutine (in C) from a given decision tree ensemble.

func NewCompiler

func NewCompiler(
	name string,
	param *CompilerParam,
) (*Compiler, error)

NewCompiler creates new compiler by the given setting

func (*Compiler) Close

func (c *Compiler) Close() error

Close frees internally allocated memory

func (*Compiler) ExportSharedLib

func (c *Compiler) ExportSharedLib(
	model *Model,
	destPath string,
	toolChain string,
	compileOptions []string,
) error

ExportSharedLib exports a shared object of the optimized model. compileOptions holds compiler options, e.g. -g.

func (*Compiler) GenerateCode

func (c *Compiler) GenerateCode(model *Model, destPath string) error

GenerateCode generates C source code from the given model

type CompilerParam

// CompilerParam holds the compiler settings. Every field carries a JSON tag
// with omitempty, so zero-valued fields are left out of the encoded JSON.
type CompilerParam struct {
	// name of model annotation file.
	// Use the class treelite.Annotator to generate this file.
	AnnotationPath string `json:"annotate_in,omitempty"`
	// whether to quantize threshold points
	Quantize BooleanInC `json:"quantize,omitempty"`
	// option to enable parallel compilation;
	// if set to nonzero, the trees will be evenly distributed into [parallel_comp] files.
	// Set this option to improve compilation time and reduce memory consumption during compilation.
	ParallelComp int `json:"parallel_comp,omitempty"`
	// produce extra messages
	Verbose BooleanInC `json:"verbose,omitempty"`
	// native lib name (without extension)
	NativeLibName string `json:"native_lib_name,omitempty"`
	// parameter for folding rarely visited subtrees (no if/else blocks);
	// all nodes whose data counts are lower than that of the root node of the decision tree by [code_folding_req] will be folded.
	// To disable folding, set to +inf. If hessian sums are available, they will be used as proxies of data counts.
	CodeFoldingReq float64 `json:"code_folding_req,omitempty"`
	// Only applicable when compiler is set to failsafe.
	// If set to a positive value, the fail-safe compiler will not emit large constant arrays to the C code.
	// Instead, the arrays will be emitted as an ELF binary (Linux only).
	// For large arrays, it is much faster to directly dump ELF binaries than to pass them to a C compiler.
	DumpArrayAsELF BooleanInC `json:"dump_array_as_elf,omitempty"`
}

CompilerParam is a compiler setting

type DMatrix

type DMatrix struct {
	// contains filtered or unexported fields
}

DMatrix is treelite DMatrix

func CreateFromCSR

func CreateFromCSR(
	header []uint64,
	indices []uint32,
	data []float32,
	nrow, ncol int,
) (*DMatrix, error)

CreateFromCSR creates a sparse DMatrix from the given data

func CreateFromMat added in v0.1.3

func CreateFromMat(
	data []float32,
	nrow, ncol int,
	missing float32,
) (*DMatrix, error)

CreateFromMat creates a dense DMatrix from the given data

func (*DMatrix) Close

func (d *DMatrix) Close() error

Close frees internally allocated memory

func (DMatrix) Col

func (d DMatrix) Col() int

Col returns the number of columns

func (DMatrix) Element

func (d DMatrix) Element() int

Element returns the number of internal elements

func (DMatrix) Row

func (d DMatrix) Row() int

Row returns the number of rows

type Model

type Model struct {
	// contains filtered or unexported fields
}

Model is a tree model

func LoadLightGBMModel

func LoadLightGBMModel(modelPath string) (*Model, error)

LoadLightGBMModel loads a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision tree ensemble.

func LoadLightGBMModelFromString

func LoadLightGBMModelFromString(modelString string) (*Model, error)

LoadLightGBMModelFromString loads a LightGBM model from a string. The string should be created with the model_to_string() method in LightGBM.

func LoadXGBoostJSON

func LoadXGBoostJSON(modelPath string) (*Model, error)

LoadXGBoostJSON loads a json model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.

func LoadXGBoostJSONString

func LoadXGBoostJSONString(modelJSON string) (*Model, error)

LoadXGBoostJSONString loads a model stored as a JSON string by XGBoost (dmlc/xgboost). The model JSON must contain a decision tree ensemble.

func LoadXGBoostModel

func LoadXGBoostModel(modelPath string) (*Model, error)

LoadXGBoostModel loads a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.

func LoadXGBoostModelFromMemoryBuffer

func LoadXGBoostModelFromMemoryBuffer(reader io.Reader) (*Model, error)

LoadXGBoostModelFromMemoryBuffer loads an XGBoost model from the given reader. This method reads the entire model into memory at once.

func (*Model) Close

func (m *Model) Close() error

Close frees internally allocated model

type Predictor

type Predictor struct {
	// contains filtered or unexported fields
}

Predictor is compiled prediction subroutines from shared libraries

func NewPredictor

func NewPredictor(libraryPath string, NumWorkerThread int) (*Predictor, error)

NewPredictor loads prediction code into memory. This function assumes that the prediction code has already been compiled into a dynamic shared library object (.so/.dll/.dylib).

func (*Predictor) Close

func (p *Predictor) Close() error

Close frees internally allocated memory

func (Predictor) DataType

func (p Predictor) DataType() string

DataType returns the data type of the predictor. Only float32 is supported.

func (Predictor) GlobalBias added in v0.1.6

func (p Predictor) GlobalBias() float32

GlobalBias returns the global bias, which adjusts predicted margin scores.

func (Predictor) NumClass

func (p Predictor) NumClass() int

NumClass returns the number of classes of the given model

func (Predictor) NumFeature added in v0.1.6

func (p Predictor) NumFeature() int

NumFeature returns the number of features of the given model

func (Predictor) PredTransform added in v0.1.6

func (p Predictor) PredTransform() string

PredTransform returns name of post prediction transformation used to train the loaded model.

func (Predictor) PredictBatch

func (p Predictor) PredictBatch(
	dmat *DMatrix,
	verbose bool,
	predMargin bool,
) ([]float32, error)

PredictBatch makes predictions on a batch of data rows (synchronously). This function internally divides the workload among all worker threads. The length of the returned scores is #row * #class; for example, if #row is 4 and #class is 3, then the length of scores is 12.

for rowID := 0; rowID < dMatrix.Row(); rowID++ {
  for classID := 0; classID < predictor.NumClass(); classID++ {
	   value = scores[rowID*predictor.NumClass()+classID]
  }
}

func (Predictor) RatioC added in v0.1.6

func (p Predictor) RatioC() float32

RatioC returns c value of exponential standard ratio transformation used to train the loaded model.

func (Predictor) SigmoidAlpha added in v0.1.6

func (p Predictor) SigmoidAlpha() float32

SigmoidAlpha returns alpha value of sigmoid transformation used to train the loaded model.

func (Predictor) ThresholdType added in v0.1.6

func (p Predictor) ThresholdType() string

ThresholdType returns threshold type.

type Recipe added in v0.1.1

// Recipe pairs a Target string with a list of Sources; it maps to JSON with
// the keys "target" and "sources".
// NOTE(review): presumably mirrors the build recipe that treelite writes
// alongside generated C code — confirm against the code that reads it.
type Recipe struct {
	// Target is stored under the JSON key "target".
	Target  string `json:"target"`
	// Sources is stored under the JSON key "sources"; each entry has a
	// "name" string and an integer "length".
	Sources []struct {
		Name   string `json:"name"`
		Length int    `json:"length"`
	} `json:"sources"`
}

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL