Documentation
Overview
Package xgboost is a pure Go implementation for loading DMLC XGBoost JSON models produced by the Python dump_model API. It supports inference for binary classification, multiclass classification and regression. The package is for inference only; for training, refer to https://github.com/dmlc/xgboost.
Training the model
To obtain a JSON-encoded model file, first train the model in Python:
iris_xgboost.py:
    import numpy as np
    import xgboost as xgb
    from sklearn import datasets
    from sklearn.datasets import dump_svmlight_file
    from sklearn.model_selection import train_test_split

    # Load iris and hold out 20% of the rows for testing.
    X, y = datasets.load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

    # Train a 3-class model with the native API; predict() returns class labels.
    dtrain = xgb.DMatrix(X_train, label=y_train)
    param = {'max_depth': 4, 'eta': 1, 'objective': 'multi:softmax', 'nthread': 4,
             'eval_metric': 'auc', 'num_class': 3}
    num_round = 10
    bst = xgb.train(param, dtrain, num_round)
    y_pred = bst.predict(xgb.DMatrix(X_test))

    # Train a scikit-learn style classifier to get class probabilities.
    # XGBClassifier infers the number of classes from y_train.
    clf = xgb.XGBClassifier(max_depth=4, objective='multi:softprob', n_estimators=10)
    clf.fit(X_train, y_train)
    y_pred_proba = clf.predict_proba(X_test)

    # Save reference predictions, the test set in libsvm format and the JSON model dump.
    np.savetxt('../data/iris_xgboost_true_prediction.txt', y_pred, delimiter='\t')
    np.savetxt('../data/iris_xgboost_true_prediction_proba.txt', y_pred_proba, delimiter='\t')
    dump_svmlight_file(X_test, y_test, '../data/iris_test.libsvm')
    bst.dump_model('../data/iris_xgboost_dump.json', dump_format='json')
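The script saves the Python predictions as tab-delimited text with np.savetxt, presumably so they can be compared with the Go predictions. A minimal Go sketch for reading such a file back follows. It assumes the default savetxt number format and one row per line. The function name parseSavetxt is hypothetical and not part of this package.

```go
package main

import (
	"bufio"
	"fmt"
	"strconv"
	"strings"
)

// parseSavetxt parses text written by NumPy's savetxt with delimiter='\t':
// one row per line, values separated by tabs. A 1-D array (such as the
// predicted labels) yields one value per row.
func parseSavetxt(text string) ([][]float64, error) {
	var rows [][]float64
	scanner := bufio.NewScanner(strings.NewReader(text))
	lineNo := 0
	for scanner.Scan() {
		lineNo++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue // skip blank lines
		}
		fields := strings.Split(line, "\t")
		row := make([]float64, len(fields))
		for i, f := range fields {
			v, err := strconv.ParseFloat(f, 64)
			if err != nil {
				return nil, fmt.Errorf("line %d: %w", lineNo, err)
			}
			row[i] = v
		}
		rows = append(rows, row)
	}
	return rows, scanner.Err()
}

func main() {
	// Two rows of class probabilities in the default '%.18e' format used by np.savetxt.
	sample := "9.000000000000000000e-01\t1.000000000000000000e-01\n" +
		"2.000000000000000000e-01\t8.000000000000000000e-01\n"
	rows, err := parseSavetxt(sample)
	if err != nil {
		panic(err)
	}
	fmt.Println(rows)
}
```

In practice the text would come from os.ReadFile on one of the files written above. Comparisons against Go predictions should use a small tolerance rather than exact equality.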
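The following program loads an exported model and prints its predicted probabilities. The arguments to LoadXGBoostFromJSON are the model path, the features map path (empty here), the number of classes, the maximum tree depth and the activation function. Adjust them to match the model you trained.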
    package main

    import (
        "fmt"

        "github.com/Elvenson/xgboost-go/activation"
        "github.com/Elvenson/xgboost-go/mat"
        "github.com/Elvenson/xgboost-go/models"
    )

    func main() {
        // Load the dumped JSON model: no features map, 1 class, max depth 4,
        // and a logistic activation.
        ensemble, err := models.LoadXGBoostFromJSON("your model path", "", 1, 4, &activation.Logistic{})
        if err != nil {
            panic(err)
        }

        // Read the input samples from a libsvm-formatted file.
        input, err := mat.ReadLibsvmFileToSparseMatrix("your libsvm input path")
        if err != nil {
            panic(err)
        }

        // Run inference and print the predicted probabilities.
        predictions, err := ensemble.PredictProba(input)
        if err != nil {
            panic(err)
        }
        fmt.Printf("%+v\n", predictions)
    }
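ReadLibsvmFileToSparseMatrix reads input in libsvm format, which is also what dump_svmlight_file writes in the training script. Each line is a label followed by sparse index:value pairs. A feature that is absent from a line is treated as missing rather than zero. The standalone sketch below only illustrates that format. It is not this package's reader, and libsvmRow and parseLibsvmLine are hypothetical names.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// libsvmRow holds one parsed line: a label and its sparse features.
type libsvmRow struct {
	Label    float64
	Features map[int]float64 // feature index -> value; an absent index means missing
}

// parseLibsvmLine parses a line of the form "label idx:val idx:val ...".
func parseLibsvmLine(line string) (libsvmRow, error) {
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return libsvmRow{}, fmt.Errorf("empty line")
	}
	label, err := strconv.ParseFloat(fields[0], 64)
	if err != nil {
		return libsvmRow{}, fmt.Errorf("bad label %q: %w", fields[0], err)
	}
	row := libsvmRow{Label: label, Features: make(map[int]float64)}
	for _, f := range fields[1:] {
		idxStr, valStr, ok := strings.Cut(f, ":")
		if !ok {
			return libsvmRow{}, fmt.Errorf("bad feature %q", f)
		}
		idx, err := strconv.Atoi(idxStr)
		if err != nil {
			return libsvmRow{}, fmt.Errorf("bad index %q: %w", idxStr, err)
		}
		val, err := strconv.ParseFloat(valStr, 64)
		if err != nil {
			return libsvmRow{}, fmt.Errorf("bad value %q: %w", valStr, err)
		}
		row.Features[idx] = val
	}
	return row, nil
}

func main() {
	// A line in the style written by scikit-learn's dump_svmlight_file
	// (zero-based feature indices by default).
	row, err := parseLibsvmLine("2 0:6.7 1:3.1 2:5.6 3:2.4")
	if err != nil {
		panic(err)
	}
	fmt.Println(row.Label, row.Features)
}
```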
For more examples, see xgbensemble_test.go.
Index
- func LoadXGBoost(xgbEnsembleJSON []*xgboostJSON, featuresMapPath string, numClasses int, ...) (*inference.Ensemble, error)
- func LoadXGBoostFromJSON(modelPath, featuresMapPath string, numClasses int, maxDepth int, ...) (*inference.Ensemble, error)
- func LoadXGBoostFromJSONBytes(jsonBytes []byte, featuresMapPath string, numClasses int, maxDepth int, ...) (*inference.Ensemble, error)
Constants
This section is empty.
Variables
This section is empty.
Functions
func LoadXGBoost (added in v0.1.3)
    func LoadXGBoost(
        xgbEnsembleJSON []*xgboostJSON,
        featuresMapPath string,
        numClasses int,
        maxDepth int,
        activation activation.Activation) (*inference.Ensemble, error)
func LoadXGBoostFromJSON
    func LoadXGBoostFromJSON(
        modelPath, featuresMapPath string,
        numClasses int,
        maxDepth int,
        activation activation.Activation) (*inference.Ensemble, error)
LoadXGBoostFromJSON loads an XGBoost model from a JSON file.
func LoadXGBoostFromJSONBytes (added in v0.1.3)
    func LoadXGBoostFromJSONBytes(
        jsonBytes []byte,
        featuresMapPath string,
        numClasses int,
        maxDepth int,
        activation activation.Activation) (*inference.Ensemble, error)
Types
This section is empty.