#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.optim.optimizer import *
from zoo.pipeline.api.keras.objectives import *
from zoo.pipeline.api.keras import metrics
[docs]def to_bigdl_optim_method(optimizer):
optimizer = optimizer.lower()
if optimizer == "adagrad":
return Adagrad(learningrate=0.01)
elif optimizer == "sgd":
return SGD(learningrate=0.01)
elif optimizer == "adam":
return Adam()
elif optimizer == "rmsprop":
return RMSprop(learningrate=0.001, decayrate=0.9)
elif optimizer == "adadelta":
return Adadelta(decayrate=0.95, epsilon=1e-8)
elif optimizer == "adamax":
return Adamax(epsilon=1e-8)
else:
raise TypeError("Unsupported optimizer: %s" % optimizer)
[docs]def to_bigdl_criterion(criterion):
criterion = criterion.lower()
if criterion == "categorical_crossentropy":
return CategoricalCrossEntropy()
elif criterion == "mse" or criterion == "mean_squared_error":
return MeanSquaredError()
elif criterion == "binary_crossentropy":
return BinaryCrossEntropy()
elif criterion == "mae" or criterion == "mean_absolute_error":
return mae()
elif criterion == "hinge":
return Hinge()
elif criterion == "mean_absolute_percentage_error" or criterion == "mape":
return MeanAbsolutePercentageError()
elif criterion == "mean_squared_logarithmic_error" or criterion == "msle":
return MeanSquaredLogarithmicError()
elif criterion == "squared_hinge":
return SquaredHinge()
elif criterion == "sparse_categorical_crossentropy":
return SparseCategoricalCrossEntropy()
elif criterion == "kullback_leibler_divergence" or criterion == "kld":
return KullbackLeiblerDivergence()
elif criterion == "poisson":
return Poisson()
elif criterion == "cosine_proximity" or criterion == "cosine":
return CosineProximity()
elif criterion == "rank_hinge":
return RankHinge()
else:
raise TypeError("Unsupported loss: %s" % criterion)
[docs]def to_bigdl_metric(metric, loss):
metric = metric.lower()
loss_str = (loss if isinstance(loss, six.string_types) else loss.__class__.__name__).lower()
if metric == "accuracy" or metric == "acc":
if loss_str == "sparse_categorical_crossentropy"\
or loss_str == "sparsecategoricalcrossentropy":
return metrics.SparseCategoricalAccuracy()
elif loss_str == "categorical_crossentropy"\
or loss_str == "categoricalcrossentropy":
return metrics.CategoricalAccuracy()
elif loss_str == "binary_crossentropy"\
or loss_str == "binarycrossentropy":
return metrics.BinaryAccuracy()
else:
raise TypeError(
"Not supported combination: metric {} and loss {}".format(metric, loss_str))
elif metric == "top5accuracy" or metric == "top5acc":
return metrics.Top5Accuracy()
elif metric == "mae":
return metrics.MAE()
elif metric == "auc":
return metrics.AUC()
elif metric == "loss":
return Loss(to_bigdl_criterion(loss_str))
elif metric == "treennaccuracy":
return TreeNNAccuracy()
else:
raise TypeError("Unsupported metric: %s" % metric)
[docs]def to_bigdl_metrics(metrics, loss):
return [to_bigdl_metric(m, loss) for m in metrics]