Documentation
Index
- type Callback
- type Evaluator
- type Learner
- type LiveEvaluator
- type RpcHandler
- type TrainResponse
- type Trainer
- func NewTrainer(address string, rh RpcHandler, cb Callback, learnerLimit int) *Trainer
- func (t *Trainer) DeleteLearner(req *pbCom.StopTaskRequest) error
- func (t *Trainer) NewLearner(req *pbCom.StartTaskRequest) error
- func (t *Trainer) SavePredictAndEvaluatResult(result *pbCom.TrainTaskResult)
- func (t *Trainer) SaveResult(result *pbCom.TrainTaskResult)
- func (t *Trainer) Train(req *pb.TrainRequest, resC chan *TrainResponse)
- func (t *Trainer) Validate(req *pb.ValidateRequest, resC chan *TrainResponse)
Constants
This section is empty.
Variables
This section is empty.
Functions
This section is empty.
Types
type Callback
type Callback interface {
	// SaveModel persists a model.
	SaveModel(*pbCom.TrainTaskResult) error

	// StartTask starts a specific training or prediction task.
	StartTask(*pbCom.StartTaskRequest) error

	// StopTask stops a training task.
	// Call it asynchronously to avoid deadlock.
	StopTask(*pbCom.StopTaskRequest) error

	// Train trains a model.
	Train(*pb.TrainRequest) (*pb.TrainResponse, error)
}
Callback contains methods that are called when training finishes, such as saving the trained model and stopping a training task. It also contains methods that are called during the evaluation phase, such as starting a specific training or prediction task and training a model. It is set into the Trainer instance during initialization.
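The note on StopTask about calling it asynchronously can be illustrated with a minimal sketch. The sketch assumes two things. First, the Trainer guards its Learners with a non-reentrant mutex. Second, StopTask eventually calls back into the Trainer. The names `trainer`, `saveResult`, `stopTask` and `deleteLearner` are hypothetical stand-ins, not the package's internals. If the callback were called synchronously while the lock is held, it would wait on that same lock forever. The sketch therefore runs the callback in a goroutine.

```go
package main

import (
	"fmt"
	"sync"
)

// trainer is a hypothetical stand-in for Trainer: learners keyed by task ID,
// guarded by a non-reentrant mutex.
type trainer struct {
	mu       sync.Mutex
	learners map[string]bool
	wg       sync.WaitGroup // tracks asynchronous StopTask calls
}

// deleteLearner plays the role of DeleteLearner and needs the lock.
func (t *trainer) deleteLearner(taskID string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.learners, taskID)
}

// stopTask plays the role of Callback.StopTask, which calls back into the trainer.
func (t *trainer) stopTask(taskID string) {
	t.deleteLearner(taskID)
}

// saveResult records a result while holding the lock. Calling t.stopTask
// directly here would deadlock, because deleteLearner would wait for the
// lock this function still holds. The goroutine can only take the lock
// after saveResult returns and releases it.
func (t *trainer) saveResult(taskID string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	// ... store the training result for taskID ...
	t.wg.Add(1)
	go func() {
		defer t.wg.Done()
		t.stopTask(taskID)
	}()
}

func main() {
	t := &trainer{learners: map[string]bool{"task-1": true}}
	t.saveResult("task-1")
	t.wg.Wait()                  // wait for the asynchronous StopTask
	fmt.Println(len(t.learners)) // 0
}
```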
type Evaluator
type Evaluator interface {
	// Start starts model evaluation. It segments the training set according to a
	// certain strategy (cross-validation, proportional random division),
	// then starts the training-validation process.
	// fileRows is returned by psi.IntersectParts after sample alignment.
	Start(fileRows [][]string) error

	// Stop deletes all the learners created by Evaluator as well as other objects.
	Stop()

	// SaveModel collects the results of the training in the evaluation phase, that is, the model.
	// If the model is successfully trained,
	// it triggers the local creation of a Model instance for validation.
	SaveModel(*pbCom.TrainTaskResult) error

	// SavePredictOut collects the prediction results in the evaluation phase.
	// When a prediction result is obtained, it checks how many prediction results have been obtained so far
	// and determines whether to start calculating the average scores for each metric.
	SavePredictOut(*pbCom.PredictTaskResult) error
}
Evaluator performs model evaluation. It supports cross-validation, leave-one-out (LOO), and validation by proportional random division. The basic steps of evaluation are:
1. Divide the dataset according to the chosen strategy.
2. Train the model.
3. Validate the model.
4. Calculate the evaluation metric scores from the prediction results obtained on the validation set.
5. Calculate the average scores for each metric.
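For cross-validation, the first step can be sketched as a k-fold split over the `fileRows` slice passed to `Start`. This is a minimal illustration and not the package's splitting strategy. `kFoldSplit` and `fold` are hypothetical names, every row is treated as a sample, and no shuffling is done. Setting k to the number of rows gives leave-one-out.

```go
package main

import "fmt"

// fold holds the training and validation rows for one round of
// cross-validation.
type fold struct {
	train [][]string
	valid [][]string
}

// kFoldSplit divides rows into k contiguous folds of nearly equal size and
// returns, for each fold, that fold as the validation set and all remaining
// rows as the training set. With k == len(rows) this is leave-one-out.
// Rows are not shuffled here; a real implementation would usually shuffle first.
func kFoldSplit(rows [][]string, k int) ([]fold, error) {
	n := len(rows)
	if k < 2 || k > n {
		return nil, fmt.Errorf("k must be in [2, %d], got %d", n, k)
	}
	folds := make([]fold, 0, k)
	start := 0
	for i := 0; i < k; i++ {
		// The first n%k folds receive one extra row.
		size := n / k
		if i < n%k {
			size++
		}
		end := start + size
		var f fold
		f.valid = rows[start:end]
		f.train = append(f.train, rows[:start]...)
		f.train = append(f.train, rows[end:]...)
		folds = append(folds, f)
		start = end
	}
	return folds, nil
}

func main() {
	rows := [][]string{{"1"}, {"2"}, {"3"}, {"4"}, {"5"}}
	folds, err := kFoldSplit(rows, 2)
	if err != nil {
		panic(err)
	}
	for i, f := range folds {
		fmt.Println("fold", i, "train:", len(f.train), "valid:", len(f.valid))
	}
}
```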
type Learner
type Learner interface {
	// Advance does calculation with local data and communicates with other nodes in the cluster to train a model step by step.
	// payload is resolved by the Learner defined by the specific algorithm.
	// Call this method asynchronously to avoid blocking the main goroutine.
	Advance(payload []byte) (*pb.TrainResponse, error)
}
Learner is assigned a specific algorithm and the data used to train a model,
and participates in the multi-party calculation during the training process.
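A toy Learner makes the `Advance` contract concrete. In this sketch, `trainResponse` is a hypothetical stand-in for `pb.TrainResponse`. `countingLearner` only counts rounds instead of running a real multi-party algorithm. `advanceAsync` shows one way to follow the advice of calling `Advance` off the main goroutine.

```go
package main

import "fmt"

// trainResponse is a hypothetical stand-in for pb.TrainResponse.
type trainResponse struct {
	Round   int
	Payload []byte
}

// learner has the same shape as the Learner interface.
type learner interface {
	Advance(payload []byte) (*trainResponse, error)
}

// countingLearner is a toy Learner that only counts rounds and refuses to
// advance after maxRounds. A real Learner would run one step of its
// algorithm on local data and exchange intermediate values with peers.
type countingLearner struct {
	round, maxRounds int
}

func (l *countingLearner) Advance(payload []byte) (*trainResponse, error) {
	if l.round >= l.maxRounds {
		return nil, fmt.Errorf("training already finished after %d rounds", l.maxRounds)
	}
	l.round++
	return &trainResponse{Round: l.round, Payload: payload}, nil
}

// result carries the outcome of one Advance call.
type result struct {
	resp *trainResponse
	err  error
}

// advanceAsync runs Advance in its own goroutine so the caller is not
// blocked, and delivers the outcome on a buffered channel.
func advanceAsync(l learner, payload []byte) <-chan result {
	ch := make(chan result, 1)
	go func() {
		resp, err := l.Advance(payload)
		ch <- result{resp: resp, err: err}
	}()
	return ch
}

func main() {
	l := &countingLearner{maxRounds: 2}
	for i := 0; i < 3; i++ {
		r := <-advanceAsync(l, []byte("step"))
		if r.err != nil {
			fmt.Println("error:", r.err)
			continue
		}
		fmt.Println("finished round", r.resp.Round)
	}
}
```

`countingLearner` is not safe for concurrent `Advance` calls. The example therefore waits for each round's result before starting the next one.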
type LiveEvaluator
type LiveEvaluator interface {
	// Trigger triggers model evaluation.
	// The parameter contains two types of messages.
	// One sets the learner for evaluation with the training set and starts it.
	// The other drives the learner to continue training. When the conditions are met (the pause round is reached),
	// it stops training and instantiates the model for validation.
	Trigger(*pb.LiveEvaluationTriggerMsg) error

	// Stop deletes all the learners created by LiveEvaluator as well as other objects.
	Stop()

	// SaveModel collects the results of the training in the evaluation phase,
	// that is, the model, for live evaluation.
	// If the model is successfully trained,
	// it triggers the local creation of a Model instance for validation.
	SaveModel(*pbCom.TrainTaskResult) error

	// SavePredictOut collects the prediction results in the evaluation phase.
	// When a prediction result is obtained, it calculates the metric scores,
	// then reports the results to the visualization system.
	SavePredictOut(*pbCom.PredictTaskResult) error
}
LiveEvaluator performs staged evaluation during training. The basic steps of LiveEvaluator are:
1. Divide the dataset by proportional random division.
2. Initiate a learner for evaluation with the training part.
3. Train the model. When the pause round is reached, pause training and instantiate the staged model for validation.
4. Calculate the evaluation metric scores from the prediction results obtained on the validation set.
5. Repeat train-pause-validate until the stop signal is received.
type RpcHandler
type RpcHandler interface {
	StepTrain(req *pb.TrainRequest, peerName string) (*pb.TrainResponse, error)

	// StepTrainWithRetry sends a training message to a remote mpc-node,
	// retrying 2 times at most.
	// inteSec indicates the interval between retry requests, in seconds.
	StepTrainWithRetry(req *pb.TrainRequest, peerName string, times int, inteSec int64) (*pb.TrainResponse, error)
}
RpcHandler performs remote procedure calls to remote cluster nodes. It is set into the Trainer instance during initialization.
type TrainResponse
type TrainResponse struct {
	Resp *pb.TrainResponse
	Err  error
}
type Trainer
type Trainer struct {
// contains filtered or unexported fields
}
Trainer manages Learners. It creates and deletes Learners, dispatches requests to different Learners by taskId, and keeps the number of Learners within a proper range to avoid high memory usage.
func NewTrainer
func NewTrainer(address string, rh RpcHandler, cb Callback, learnerLimit int) *Trainer
NewTrainer creates a Trainer instance. address is the local mpc-node address, learnerLimit is the upper limit on the number of Learners, rh is the handler used to send RPC requests, and cb contains the callback methods called when training finishes.
func (*Trainer) DeleteLearner
func (t *Trainer) DeleteLearner(req *pbCom.StopTaskRequest) error
DeleteLearner deletes a task from Memory Storage.
func (*Trainer) NewLearner
func (t *Trainer) NewLearner(req *pbCom.StartTaskRequest) error
NewLearner creates a Learner instance for the TaskId and stores it in Memory Storage. It keeps the number of Learners within a proper range to avoid high memory usage.
func (*Trainer) SavePredictAndEvaluatResult
func (t *Trainer) SavePredictAndEvaluatResult(result *pbCom.TrainTaskResult)
SavePredictAndEvaluatResult saves the training result and the evaluation result for a Learner and stops the related task. It is called only by the Evaluator.
func (*Trainer) SaveResult
func (t *Trainer) SaveResult(result *pbCom.TrainTaskResult)
SaveResult saves the training result (failed or successful status) for a Learner and stops the related task. It analyzes the TaskID to determine whether the training task is a regular task from the user or a task from the Evaluator. In the former case, if the user did not ask for evaluation, it persists the prediction results locally;
otherwise, it calls Evaluator.Start() to start the evaluation process.
In the latter case, it calls Evaluator.SaveModel().
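The branching in SaveResult can be summarised as a small decision function. The TaskID format is not documented here. This sketch therefore assumes a purely hypothetical `eval-` prefix that marks tasks created by the Evaluator. `routeResult` and the `action` values are illustrative only.

```go
package main

import (
	"fmt"
	"strings"
)

// action names what SaveResult does next (illustrative only).
type action string

const (
	persistLocally     action = "persist results locally"
	startEvaluation    action = "call Evaluator.Start()"
	saveEvaluatorModel action = "call Evaluator.SaveModel()"
)

// evalTaskPrefix is a HYPOTHETICAL marker for tasks created by the
// Evaluator; the real TaskID format is not documented here.
const evalTaskPrefix = "eval-"

// routeResult mirrors the decision described for SaveResult.
func routeResult(taskID string, evaluationRequested bool) action {
	if strings.HasPrefix(taskID, evalTaskPrefix) {
		return saveEvaluatorModel // task created by the Evaluator
	}
	if evaluationRequested {
		return startEvaluation // user task that asked for evaluation
	}
	return persistLocally // user task without evaluation
}

func main() {
	fmt.Println(routeResult("job-42", false))
	fmt.Println(routeResult("job-42", true))
	fmt.Println(routeResult("eval-job-42-fold-1", false))
}
```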
func (*Trainer) Train
func (t *Trainer) Train(req *pb.TrainRequest, resC chan *TrainResponse)
Train dispatches requests to different Learners by taskId. The result is returned through resC, which must not be nil.
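The `resC` contract can be sketched as follows. `trainRequest`, `trainResponse`, `advanceFunc` and `trainer` are hypothetical stand-ins for the protobuf messages, the Learner and the Trainer. The local `TrainResponse` mirrors the documented struct. The sketch assumes that Train sends exactly one value on `resC` per call. For that reason, the caller either buffers the channel or reads from it in another goroutine.

```go
package main

import (
	"fmt"
	"sync"
)

// trainRequest and trainResponse are hypothetical stand-ins for
// pb.TrainRequest and pb.TrainResponse.
type trainRequest struct {
	TaskID  string
	Payload []byte
}

type trainResponse struct{ Payload []byte }

// TrainResponse mirrors the documented wrapper: a response or an error.
type TrainResponse struct {
	Resp *trainResponse
	Err  error
}

// advanceFunc stands in for a Learner's Advance method.
type advanceFunc func(payload []byte) (*trainResponse, error)

// trainer is a hypothetical Trainer holding learners keyed by task ID.
type trainer struct {
	mu       sync.RWMutex
	learners map[string]advanceFunc
}

// Train looks up the learner by task ID and sends exactly one
// TrainResponse on resC, so resC must not be nil.
func (t *trainer) Train(req *trainRequest, resC chan *TrainResponse) {
	t.mu.RLock()
	advance, ok := t.learners[req.TaskID]
	t.mu.RUnlock()
	if !ok {
		resC <- &TrainResponse{Err: fmt.Errorf("no learner for task %q", req.TaskID)}
		return
	}
	resp, err := advance(req.Payload)
	resC <- &TrainResponse{Resp: resp, Err: err}
}

func main() {
	t := &trainer{learners: map[string]advanceFunc{
		"task-1": func(p []byte) (*trainResponse, error) { return &trainResponse{Payload: p}, nil },
	}}
	resC := make(chan *TrainResponse, 1) // buffered so Train never blocks on send
	go t.Train(&trainRequest{TaskID: "task-1", Payload: []byte("hi")}, resC)
	res := <-resC
	fmt.Println(string(res.Resp.Payload), res.Err) // hi <nil>
}
```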
func (*Trainer) Validate
func (t *Trainer) Validate(req *pb.ValidateRequest, resC chan *TrainResponse)
Validate saves the prediction results to the Evaluator or LiveEvaluator, then triggers the subsequent validation process.