zoo.models.common package
Submodules
zoo.models.common.ranker module
-
class zoo.models.common.ranker.Ranker(jvalue, bigdl_type, *args)
Bases: bigdl.util.common.JavaValue
Base class for ranking models (e.g., TextMatcher and Ranker) that provides validation methods with different metrics.
-
evaluate_map(x, threshold=0.0)
Evaluate with mean average precision (MAP) on a TextSet.
Parameters: - x – TextSet. Each TextFeature should contain a Sample with batch features and labels. In other words, each Sample should be a batch of records that has both positive and negative labels.
- threshold – Float. A record whose label is greater than threshold is considered positive. Default is 0.0.
Returns: Float. The MAP result.
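To make the metric concrete, here is a minimal pure-Python sketch of MAP over batches of (score, label) records. It assumes each batch corresponds to one Sample (one query's candidate records) and that records are ranked by predicted score, highest first. The helper names average_precision and mean_average_precision are hypothetical, and scoring a batch with no positive records as AP = 0.0 is an assumption of this sketch; it is not the Analytics Zoo implementation.

```python
from typing import Iterable, Sequence, Tuple


def average_precision(scores: Sequence[float], labels: Sequence[float],
                      threshold: float = 0.0) -> float:
    """Hypothetical helper: AP for one batch (one query's candidates).

    Records are ranked by score, highest first. A record is positive
    when its label > threshold. A batch without positives scores 0.0
    (an assumption made for this sketch).
    """
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, idx in enumerate(ranked, start=1):
        if labels[idx] > threshold:
            hits += 1
            # precision at this rank, counted only where a positive appears
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0


def mean_average_precision(
        batches: Iterable[Tuple[Sequence[float], Sequence[float]]],
        threshold: float = 0.0) -> float:
    """Hypothetical helper: mean of the per-batch average precisions."""
    aps = [average_precision(s, lab, threshold) for s, lab in batches]
    return sum(aps) / len(aps) if aps else 0.0


if __name__ == "__main__":
    batches = [([0.9, 0.8, 0.1], [1, 0, 1]), ([0.2, 0.7], [1, 0])]
    print(mean_average_precision(batches))
```

In the first batch the positives sit at ranks 1 and 3, giving AP = (1/1 + 2/3) / 2; in the second the only positive sits at rank 2, giving AP = 1/2.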
-
evaluate_ndcg(x, k, threshold=0.0)
Evaluate with normalized discounted cumulative gain (NDCG) on a TextSet.
Parameters: - x – TextSet. Each TextFeature should contain a Sample with batch features and labels. In other words, each Sample should be a batch of records that has both positive and negative labels.
- k – Positive int. The rank position at which NDCG is computed, i.e. only the top k ranked records contribute.
- threshold – Float. A record whose label is greater than threshold is considered positive. Default is 0.0.
Returns: Float. The NDCG result.
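A minimal pure-Python sketch of NDCG@k for one batch of records, assuming the common exponential gain 2^label - 1 for records whose label exceeds threshold (0 otherwise) and a log2(rank + 1) discount. Analytics Zoo may use a different gain or log base. The log base does not affect NDCG, because DCG and the ideal DCG are scaled by the same constant. The names dcg_at_k and ndcg_at_k are hypothetical, and returning 0.0 for a batch with no positive records is an assumption.

```python
import math
from typing import Sequence


def dcg_at_k(gains: Sequence[float], k: int) -> float:
    """Discounted cumulative gain of the first k gains (rank i -> log2(i + 1))."""
    return sum(g / math.log2(i + 1) for i, g in enumerate(gains[:k], start=1))


def ndcg_at_k(scores: Sequence[float], labels: Sequence[float], k: int,
              threshold: float = 0.0) -> float:
    """Hypothetical helper: NDCG@k for one batch (one query's candidates)."""
    if k <= 0:
        raise ValueError("k must be a positive int")

    def gain(label: float) -> float:
        # Only records with label > threshold count as relevant.
        return 2.0 ** label - 1.0 if label > threshold else 0.0

    # Gains in the order the model ranks the records (highest score first).
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    actual = [gain(labels[i]) for i in ranked]
    # Best possible ordering: gains sorted in descending order.
    ideal = sorted((gain(lab) for lab in labels), reverse=True)
    idcg = dcg_at_k(ideal, k)
    return dcg_at_k(actual, k) / idcg if idcg > 0 else 0.0


if __name__ == "__main__":
    print(ndcg_at_k([0.9, 0.5, 0.1], [2, 1, 0], k=3))  # perfect ranking
    print(ndcg_at_k([0.1, 0.5, 0.9], [2, 1, 0], k=3))  # reversed ranking
```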
-
zoo.models.common.zoo_model module
-
class zoo.models.common.zoo_model.KerasZooModel(jvalue, bigdl_type, *args)
Bases: zoo.models.common.zoo_model.ZooModel
The base class for Keras-style models in Analytics Zoo.
-
evaluate(x, y=None, batch_size=32)
With no arguments: set the model to evaluation mode (train = false), which is useful for testing or running forward passes. Returns: the layer itself.
With three arguments: benchmark the model quality.
Parameters: - dataset – the input data
- batch_size – batch size
- val_methods – a list of validation methods, e.g. Top1Accuracy, Top5Accuracy and Loss.
Returns: a list of the metric results.
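As a rough illustration of what the Top1Accuracy and Top5Accuracy validation methods measure, the pure-Python sketch below computes top-k accuracy from per-class scores. It assumes zero-based integer targets, and top_k_accuracy is a hypothetical helper, not part of BigDL or Analytics Zoo.

```python
from typing import Sequence


def top_k_accuracy(probs: Sequence[Sequence[float]], targets: Sequence[int],
                   k: int) -> float:
    """Hypothetical helper: fraction of rows whose target is among the k best scores."""
    if not targets:
        raise ValueError("targets must not be empty")
    correct = 0
    for row, target in zip(probs, targets):
        # Indices of the k highest-scoring classes for this row.
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if target in top:
            correct += 1
    return correct / len(targets)


if __name__ == "__main__":
    probs = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    targets = [1, 1, 0]
    print(top_k_accuracy(probs, targets, k=1))  # Top-1 accuracy
    print(top_k_accuracy(probs, targets, k=2))  # Top-2 accuracy
```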
-
fit(x, y=None, batch_size=32, nb_epoch=10, validation_split=0, validation_data=None, distributed=True)
-
predict(x, batch_per_thread=4, distributed=True)
Model inference based on the given data.
Parameters: - features – an ndarray or a list of ndarrays for local inference, or an RDD[Sample] for distributed inference.
- batch_size – total batch size of prediction.
Returns: an ndarray or RDD[Sample], depending on the type of features.
-
predict_classes(x, batch_per_thread=4, distributed=True)
Predict for classes. By default, label predictions start from 0.
Parameters: - x – Prediction data. A Numpy array or RDD of Sample.
- batch_size – Number of samples per batch. Default is 32.
- zero_based_label – Boolean. Whether result labels start from 0. Default is True. If False, result labels start from 1.
-
class zoo.models.common.zoo_model.ZooModel(jvalue, bigdl_type, *args)
Bases: zoo.models.common.zoo_model.ZooModelCreator, bigdl.nn.layer.Container
The base class for models in Analytics Zoo.
-
predict_classes(x, batch_size=32, zero_based_label=True)
Predict for classes. By default, label predictions start from 0.
Parameters: - x – Prediction data. A Numpy array or RDD of Sample.
- batch_size – Number of samples per batch. Default is 32.
- zero_based_label – Boolean. Whether result labels start from 0. Default is True. If False, result labels start from 1.
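The label offset controlled by zero_based_label can be illustrated with a minimal sketch that takes the arg-max of each row of class probabilities and shifts it by one when zero_based_label is False. It assumes the model has already produced per-class probabilities as plain Python lists; labels_from_probs is a hypothetical helper, not the library method.

```python
from typing import List, Sequence


def labels_from_probs(probs: Sequence[Sequence[float]],
                      zero_based_label: bool = True) -> List[int]:
    """Hypothetical helper: arg-max class per row, optionally 1-based."""
    offset = 0 if zero_based_label else 1
    # max() with a key returns the first index holding the highest probability.
    return [max(range(len(row)), key=lambda c: row[c]) + offset for row in probs]


if __name__ == "__main__":
    probs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
    print(labels_from_probs(probs))                          # zero-based labels
    print(labels_from_probs(probs, zero_based_label=False))  # one-based labels
```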
-
save_model(path, weight_path=None, over_write=False)
Save the model to the specified path.
Parameters: - path – The path to save the model. Local file system, HDFS and Amazon S3 are supported. An HDFS path should look like ‘hdfs://[host]:[port]/xxx’; an Amazon S3 path should look like ‘s3a://bucket/xxx’.
- weight_path – The path to save weights. Default is None.
- over_write – Whether to overwrite the file if it already exists. Default is False.
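The documented path conventions can be checked before calling save_model. The sketch below is a hypothetical helper, storage_kind, built on the standard library's urllib.parse. It follows the patterns above strictly (for example, it requires a port in HDFS paths), while the library itself may accept other forms.

```python
from urllib.parse import urlparse


def storage_kind(path: str) -> str:
    """Hypothetical helper: classify a save_model path by file system."""
    parsed = urlparse(path)
    if parsed.scheme == "hdfs":
        # Documented form: hdfs://[host]:[port]/xxx
        if not parsed.hostname or parsed.port is None:
            raise ValueError("HDFS path should look like hdfs://[host]:[port]/xxx")
        return "hdfs"
    if parsed.scheme == "s3a":
        # Documented form: s3a://bucket/xxx
        if not parsed.netloc:
            raise ValueError("S3 path should look like s3a://bucket/xxx")
        return "s3"
    if parsed.scheme in ("", "file"):
        return "local"
    raise ValueError(f"unsupported scheme: {parsed.scheme!r}")


if __name__ == "__main__":
    for p in ["hdfs://namenode:9000/models/m", "s3a://bucket/models/m", "/tmp/model"]:
        print(p, "->", storage_kind(p))
```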
-