zoo.models.common package

Submodules

zoo.models.common.ranker module

class zoo.models.common.ranker.Ranker(jvalue, bigdl_type, *args)[source]

Bases: bigdl.util.common.JavaValue

Base class for Ranking models (e.g., TextMatcher and Ranker) that provides validation methods with different metrics.

evaluate_map(x, threshold=0.0)[source]

Evaluate using mean average precision on TextSet.

Parameters:
  • x – TextSet. Each TextFeature should contain a Sample with batch features and labels. In other words, each Sample should be a batch of records having both positive and negative labels.
  • threshold – Float. A record whose label is greater than threshold is considered positive. Default is 0.0.
Returns:

Float. MAP result.
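
To make the role of threshold concrete, the following is a minimal pure-Python sketch of mean average precision over batches of (labels, scores). It illustrates the metric only and is not the library's implementation: the function names average_precision and mean_average_precision are hypothetical, and the sketch assumes that records in a batch are ranked by predicted score in descending order and that a batch without positive records contributes 0.0.

```python
def average_precision(labels, scores, threshold=0.0):
    """Average precision for one batch (one query), hypothetical helper."""
    # Rank records by predicted score, highest first (assumption).
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, (label, _score) in enumerate(ranked, start=1):
        if label > threshold:  # record counts as positive
            hits += 1
            precision_sum += hits / rank  # precision at this rank
    # Assumption: a batch with no positive record scores 0.0.
    return precision_sum / hits if hits else 0.0


def mean_average_precision(batches, threshold=0.0):
    """Mean of per-batch average precision; batches is a list of (labels, scores)."""
    if not batches:
        return 0.0
    total = sum(average_precision(l, s, threshold) for l, s in batches)
    return total / len(batches)
```

Each batch here plays the role of one Sample: a group of candidate records for the same query, with at least one positive and one negative label.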

evaluate_ndcg(x, k, threshold=0.0)[source]

Evaluate using normalized discounted cumulative gain on TextSet.

Parameters:
  • x – TextSet. Each TextFeature should contain a Sample with batch features and labels. In other words, each Sample should be a batch of records having both positive and negative labels.
  • k – Positive int. Rank position.
  • threshold – Float. A record whose label is greater than threshold is considered positive. Default is 0.0.
Returns:

Float. NDCG result.
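
As an illustration of how k and threshold interact, here is a small pure-Python sketch of NDCG at rank k for a single batch. It uses one common formulation (gain 2^label - 1, discount log2(position + 2)); the exact gain and discount used by the library are not stated above, so treat those choices, and the hypothetical names dcg_at_k and ndcg_at_k, as assumptions.

```python
import math


def dcg_at_k(labels, k, threshold=0.0):
    """Discounted cumulative gain over the first k labels (hypothetical helper)."""
    total = 0.0
    for position, label in enumerate(labels[:k]):
        if label > threshold:  # only positive records contribute
            # Assumed gain and discount; other formulations exist.
            total += (2.0 ** label - 1.0) / math.log2(position + 2)
    return total


def ndcg_at_k(labels, scores, k, threshold=0.0):
    """NDCG@k for one batch: DCG of the predicted order over DCG of the ideal order."""
    if k <= 0:
        raise ValueError("k must be a positive int")
    # Labels ordered by predicted score, highest first.
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    by_score = [label for label, _score in ranked]
    # Ideal ordering: labels sorted from most to least relevant.
    ideal = sorted(labels, reverse=True)
    ideal_dcg = dcg_at_k(ideal, k, threshold)
    # Assumption: a batch with no positive record scores 0.0.
    return dcg_at_k(by_score, k, threshold) / ideal_dcg if ideal_dcg > 0 else 0.0
```

A result over a whole dataset would then be the mean of ndcg_at_k across batches, in the same way as for MAP.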

zoo.models.common.zoo_model module

class zoo.models.common.zoo_model.KerasZooModel(jvalue, bigdl_type, *args)[source]

Bases: zoo.models.common.zoo_model.ZooModel

The base class for Keras style models in Analytics Zoo.

clear_gradient_clipping()[source]
compile(optimizer, loss, metrics=None)[source]
evaluate(x, y=None, batch_size=32)[source]

With no arguments: switch the model to evaluation mode (train = false), which is useful for test/forward passes. Returns the layer itself.

With three arguments: benchmark the model quality.

Parameters:
  • dataset – the input data
  • batch_size – batch size
  • val_methods – a list of validation methods, e.g. Top1Accuracy, Top5Accuracy and Loss.
Returns:

A list of the metric results.

fit(x, y=None, batch_size=32, nb_epoch=10, validation_split=0, validation_data=None, distributed=True)[source]
get_train_summary(tag=None)[source]
get_validation_summary(tag=None)[source]
predict(x, batch_per_thread=4, distributed=True)[source]

Run model inference on the given data.

Parameters:
  • features – an ndarray or a list of ndarrays for local inference, or an RDD[Sample] for distributed inference.
  • batch_size – total batch size of prediction.
Returns:

ndarray or RDD[Sample], depending on the type of features.
predict_classes(x, batch_per_thread=4, distributed=True)[source]

Predict for classes. By default, label predictions start from 0.

Parameters:
  • x – Prediction data. A Numpy array or RDD of Sample.
  • batch_size – Number of samples per batch. Default is 32.
  • zero_based_label – Boolean. Whether result labels start from 0. Default is True. If False, result labels start from 1.
set_checkpoint(path, over_write=True)[source]
set_constant_gradient_clipping(min, max)[source]
set_evaluate_status()[source]

Set the model to be in evaluate status, i.e. remove the effect of Dropout, etc.

set_gradient_clipping_by_l2_norm(clip_norm)[source]
set_tensorboard(log_dir, app_name)[source]
class zoo.models.common.zoo_model.ZooModel(jvalue, bigdl_type, *args)[source]

Bases: zoo.models.common.zoo_model.ZooModelCreator, bigdl.nn.layer.Container

The base class for models in Analytics Zoo.

predict_classes(x, batch_size=32, zero_based_label=True)[source]

Predict for classes. By default, label predictions start from 0.

Parameters:
  • x – Prediction data. A Numpy array or RDD of Sample.
  • batch_size – Number of samples per batch. Default is 32.
  • zero_based_label – Boolean. Whether result labels start from 0. Default is True. If False, result labels start from 1.
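
The effect of zero_based_label can be sketched without the library. The helper below is hypothetical and assumes a multi-class output in which each row holds one score per class and the predicted class is the index of the largest score; it only shows how the two label conventions differ.

```python
def to_class_labels(outputs, zero_based_label=True):
    """Map per-class scores to class labels (hypothetical helper, not the library API)."""
    labels = []
    for row in outputs:
        # Index of the largest score in this row (assumed argmax behaviour).
        best = max(range(len(row)), key=row.__getitem__)
        # Shift by one when labels are expected to start from 1.
        labels.append(best if zero_based_label else best + 1)
    return labels
```

With scores [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]], the zero-based convention yields classes 1 and 0, while zero_based_label=False yields 2 and 1.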
save_model(path, weight_path=None, over_write=False)[source]

Save the model to the specified path.

Parameters:
  • path – The path to save the model. Local file system, HDFS and Amazon S3 are supported. HDFS path should be like ‘hdfs://[host]:[port]/xxx’. Amazon S3 path should be like ‘s3a://bucket/xxx’.
  • weight_path – The path to save weights. Default is None.
  • over_write – Whether to overwrite the file if it already exists. Default is False.

set_evaluate_status()[source]

Set the model to be in evaluate status, i.e. remove the effect of Dropout, etc.

summary()[source]

Print out the summary of the model.

class zoo.models.common.zoo_model.ZooModelCreator(jvalue, bigdl_type, *args)[source]

Bases: bigdl.util.common.JavaValue

jvm_class_constructor()[source]

Module contents