zoo.pipeline.estimator package

Submodules

zoo.pipeline.estimator.estimator module

class zoo.pipeline.estimator.estimator.Estimator(model, optim_methods=None, model_dir=None, jvalue=None, bigdl_type='float')

Bases: bigdl.util.common.JavaValue

Estimator class for training and evaluating BigDL models.

Estimator wraps a model and provides a uniform interface for training, evaluation, and prediction, both on a local host and in a distributed Spark environment.

clear_gradient_clipping()

Clear the gradient clipping parameters so that no gradient clipping is applied. To take effect, this must be called before training starts.

evaluate(validation_set, validation_method, batch_size=32)

Evaluate the model on validation_set using validation_method.
:param validation_set: validation FeatureSet, a FeatureSet[Sample[T]]
:param validation_method: validation methods
:param batch_size: batch size
:return: validation results

evaluate_imagefeature(validation_set, validation_method, batch_size=32)

Evaluate the model on an ImageFeature validation_set using validation_method.
:param validation_set: validation FeatureSet, a FeatureSet[ImageFeature]
:param validation_method: validation methods
:param batch_size: batch size
:return: validation results

evaluate_minibatch(validation_set, validation_method)

Evaluate the model on validation_set using validation_method.
:param validation_set: validation FeatureSet, a FeatureSet[MiniBatch[T]]
:param validation_method: validation methods
:return: validation results

get_train_summary(tag=None)

Get a scalar from the model's training summary. Returns a 2-D array-like object that can be converted with np.array().
:param tag: the name of the scalar to retrieve
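
The returned object can be turned into a NumPy array as described above. The sketch below assumes each row holds (step, value, wall-clock time); that column layout is not stated on this page, so check it against your BigDL/Zoo version. The helper latest_scalar and the sample rows are hypothetical.

```python
import numpy as np


def latest_scalar(summary):
    """Return (step, value) from the last row of a scalar summary.

    Assumption (hypothetical helper): each row is laid out as
    (step, value, wall-clock time).
    """
    arr = np.array(summary)  # 2-D array-like -> ndarray
    if arr.ndim != 2 or arr.shape[0] == 0:
        raise ValueError("expected a non-empty 2-D summary")
    step, value = arr[-1, 0], arr[-1, 1]
    return int(step), float(value)


# Made-up rows standing in for the result of get_train_summary(tag)
rows = [[1, 2.30, 1000.0], [2, 1.95, 1001.5], [3, 1.70, 1003.0]]
print(latest_scalar(rows))  # (3, 1.7)
```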

get_validation_summary(tag=None)

Get a scalar from the model's validation summary. Returns a 2-D array-like object that can be converted with np.array().

Note: the tag does not always match the metric name. Look up the tag to pass in the table below; the left column is the metric you used, and the right column is the tag to pass.

Metric                        | Tag
'Accuracy'                    | 'Top1Accuracy'
'BinaryAccuracy'              | 'Top1Accuracy'
'CategoricalAccuracy'         | 'Top1Accuracy'
'SparseCategoricalAccuracy'   | 'Top1Accuracy'
'AUC'                         | 'AucScore'
'HitRatio'                    | 'HitRate@k' (k is Top-k)
'Loss'                        | 'Loss'
'MAE'                         | 'MAE'
'NDCG'                        | 'NDCG'
'TFValidationMethod'          | '${name + " " + valMethod.toString()}'
'Top5Accuracy'                | 'Top5Accuracy'
'TreeNNAccuracy'              | 'TreeNNAccuracy()'
'MeanAveragePrecision'        | 'MAP@k' (k is Top-k) (BigDL)
'MeanAveragePrecision'        | 'PascalMeanAveragePrecision' (Zoo)
'StatelessMetric'             | '${name}'

:param tag: the name of the scalar to retrieve
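
The metric-to-tag lookup can also be written as a small helper. This is a hypothetical sketch, not part of the Estimator API: it covers only the entries with a fixed tag plus the two Top-k entries, and leaves out TFValidationMethod and StatelessMetric, whose tags depend on the metric's name.

```python
# Metric -> tag pairs that have a fixed tag, taken from the mapping above.
FIXED_TAGS = {
    "Accuracy": "Top1Accuracy",
    "BinaryAccuracy": "Top1Accuracy",
    "CategoricalAccuracy": "Top1Accuracy",
    "SparseCategoricalAccuracy": "Top1Accuracy",
    "AUC": "AucScore",
    "Loss": "Loss",
    "MAE": "MAE",
    "NDCG": "NDCG",
    "Top5Accuracy": "Top5Accuracy",
    "TreeNNAccuracy": "TreeNNAccuracy()",
}


def validation_tag(metric, k=None, zoo_map=False):
    """Return the get_validation_summary tag for a metric name.

    Hypothetical helper. zoo_map=True selects the Zoo variant of
    MeanAveragePrecision; otherwise the BigDL 'MAP@k' form is used.
    """
    if metric in FIXED_TAGS:
        return FIXED_TAGS[metric]
    if metric == "MeanAveragePrecision" and zoo_map:
        return "PascalMeanAveragePrecision"
    if metric in ("HitRatio", "MeanAveragePrecision"):
        if k is None:
            raise ValueError(metric + " needs its Top-k value k")
        prefix = "HitRate" if metric == "HitRatio" else "MAP"
        return "%s@%d" % (prefix, k)
    raise KeyError("no fixed tag for metric " + repr(metric))


print(validation_tag("BinaryAccuracy"))  # Top1Accuracy
print(validation_tag("HitRatio", k=10))  # HitRate@10
```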

set_constant_gradient_clipping(min, max)

Set constant gradient clipping for the training process. To take effect, this must be called before training starts.
:param min: the minimum value to clip by
:param max: the maximum value to clip by
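
As a minimal sketch of what constant clipping does to a gradient, the plain-Python function below clamps each element into [min, max]. It illustrates the technique only and is not BigDL's implementation; the name clip_constant is hypothetical, and the bounds are renamed min_value and max_value to avoid shadowing Python's built-in min and max.

```python
def clip_constant(grads, min_value, max_value):
    """Clamp every gradient element into [min_value, max_value].

    Illustrative plain-Python sketch of constant gradient clipping.
    """
    if min_value > max_value:
        raise ValueError("min_value must not exceed max_value")
    return [max(min_value, min(max_value, g)) for g in grads]


print(clip_constant([-3.0, -0.5, 0.2, 4.0], -1.0, 1.0))  # [-1.0, -0.5, 0.2, 1.0]
```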

set_l2_norm_gradient_clipping(clip_norm)

Clip gradients to a maximum L2 norm during training. To take effect, this must be called before training starts.
:param clip_norm: gradient L2-norm threshold
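
A minimal sketch of L2-norm clipping, assuming the norm is computed over all gradient elements together (the common "global norm" convention; this page does not say how BigDL groups gradients). When the norm exceeds clip_norm, every element is scaled by clip_norm / norm; otherwise the gradients are returned unchanged. The function clip_by_l2_norm is hypothetical.

```python
import math


def clip_by_l2_norm(grads, clip_norm):
    """Rescale gradients so their global L2 norm is at most clip_norm.

    Sketch only, assuming one norm over all gradient elements.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= clip_norm:
        return list(grads)  # already within the threshold: unchanged
    scale = clip_norm / norm
    return [g * scale for g in grads]


print(clip_by_l2_norm([3.0, 4.0], 1.0))  # norm is 5.0, so each element is scaled by 0.2
```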

set_tensorboad(log_dir, app_name)

Record summary information during training for visualization. The saved summaries can be viewed with TensorBoard. To take effect, this must be called before training starts.

Training summary will be saved to ‘log_dir/app_name/train’ and validation summary (if any) will be saved to ‘log_dir/app_name/validation’.

:param log_dir: the base directory path for storing training and validation logs
:param app_name: the name of the application
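
A small sketch of where the summaries end up, following the directory layout described above; summary_dirs and the example paths are hypothetical, and os.path.join is used so the separator matches the host OS.

```python
import os


def summary_dirs(log_dir, app_name):
    """Return the (train, validation) summary directories.

    Mirrors the documented layout: log_dir/app_name/train and
    log_dir/app_name/validation.
    """
    base = os.path.join(log_dir, app_name)
    return os.path.join(base, "train"), os.path.join(base, "validation")


train_dir, val_dir = summary_dirs("/tmp/zoo_logs", "my_app")
print(train_dir)  # /tmp/zoo_logs/my_app/train on POSIX systems
print(val_dir)    # /tmp/zoo_logs/my_app/validation on POSIX systems
```

Pointing TensorBoard at the application directory, for example tensorboard --logdir /tmp/zoo_logs/my_app, should then show the train and validation runs side by side.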

train(train_set, criterion, end_trigger=None, checkpoint_trigger=None, validation_set=None, validation_method=None, batch_size=32)

Train the model on train_set with the given criterion. Training continues until end_trigger fires. During training, if checkpoint_trigger is defined and fires, the model is saved to model_dir; if validation_set and validation_method are also defined, the model is evaluated at each checkpoint.
:param train_set: training FeatureSet, a FeatureSet[Sample[T]]
:param criterion: loss function
:param end_trigger: when to finish training
:param checkpoint_trigger: when to save a checkpoint and evaluate the model
:param validation_set: validation FeatureSet, a FeatureSet[Sample[T]]
:param validation_method: validation methods
:param batch_size: batch size
:return: Estimator
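
The trigger logic described above can be sketched as a plain-Python loop. This is an illustration under simplifying assumptions, not BigDL's optimizer: triggers are modeled as predicates on the iteration count (real BigDL triggers such as MaxEpoch are objects), and saving and validation are stand-in callbacks. run_training is hypothetical, and validate_fn stands for the case where both validation_set and validation_method are given.

```python
def run_training(step_fn, end_trigger, checkpoint_trigger=None,
                 save_fn=None, validate_fn=None):
    """Illustrative control flow for train(); returns the checkpoint events."""
    iteration = 0
    events = []
    while not end_trigger(iteration):
        step_fn(iteration)  # one optimization step
        iteration += 1
        if checkpoint_trigger is not None and checkpoint_trigger(iteration):
            if save_fn is not None:
                save_fn(iteration)  # stand-in for saving the model to model_dir
                events.append(("save", iteration))
            if validate_fn is not None:
                validate_fn(iteration)  # only when validation data and methods exist
                events.append(("validate", iteration))
    return events


events = run_training(
    step_fn=lambda it: None,                    # stand-in for a real update
    end_trigger=lambda it: it >= 6,             # stop after 6 iterations
    checkpoint_trigger=lambda it: it % 3 == 0,  # checkpoint every 3 iterations
    save_fn=lambda it: None,
    validate_fn=lambda it: None,
)
print(events)  # [('save', 3), ('validate', 3), ('save', 6), ('validate', 6)]
```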

train_imagefeature(train_set, criterion, end_trigger=None, checkpoint_trigger=None, validation_set=None, validation_method=None, batch_size=32)

Train the model on an ImageFeature train_set with the given criterion. Training continues until end_trigger fires. During training, if checkpoint_trigger is defined and fires, the model is saved to model_dir; if validation_set and validation_method are also defined, the model is evaluated at each checkpoint.
:param train_set: training FeatureSet, a FeatureSet[ImageFeature]
:param criterion: loss function
:param end_trigger: when to finish training
:param checkpoint_trigger: when to save a checkpoint and evaluate the model
:param validation_set: validation FeatureSet, a FeatureSet[ImageFeature]
:param validation_method: validation methods
:param batch_size: batch size

train_minibatch(train_set, criterion, end_trigger=None, checkpoint_trigger=None, validation_set=None, validation_method=None)

Train the model on an already-batched train_set with the given criterion. Training continues until end_trigger fires. During training, if checkpoint_trigger is defined and fires, the model is saved to model_dir; if validation_set and validation_method are also defined, the model is evaluated at each checkpoint.
:param train_set: training FeatureSet, a FeatureSet[MiniBatch[T]]
:param criterion: loss function
:param end_trigger: when to finish training
:param checkpoint_trigger: when to save a checkpoint and evaluate the model
:param validation_set: validation FeatureSet, a FeatureSet[MiniBatch[T]]
:param validation_method: validation methods
:return: Estimator
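
Unlike train and train_imagefeature, train_minibatch takes no batch_size, because its input is already grouped into minibatches. A minimal plain-Python sketch of that grouping step follows; to_minibatches is hypothetical, and it keeps a final partial batch, which may differ from what BigDL does.

```python
def to_minibatches(samples, batch_size):
    """Group samples into lists of at most batch_size items.

    Hypothetical sketch; a trailing partial batch is kept.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]


print(to_minibatches(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```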

Module contents