Source code for zoo.models.textclassification.text_classifier

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import warnings

from zoo.pipeline.api.keras.models import Sequential
from zoo.pipeline.api.keras.layers import *
from zoo.models.common import ZooModel
from zoo.common.utils import callZooFunc

if sys.version >= '3':
    long = int
    unicode = str


[docs]class TextClassifier(ZooModel): """ The model used for text classification with WordEmbedding as its first layer. # Arguments class_num: The number of text categories to be classified. Positive int. embedding_file: The path to the word embedding file. Currently only the following GloVe files are supported: "glove.6B.50d.txt", "glove.6B.100d.txt", "glove.6B.200d.txt" "glove.6B.300d.txt", "glove.42B.300d.txt", "glove.840B.300d.txt". You can download from: https://nlp.stanford.edu/projects/glove/. word_index: Dictionary of word (string) and its corresponding index (int). The index is supposed to start from 1 with 0 reserved for unknown words. During the prediction, if you have words that are not in the word_index for the training, you can map them to index 0. Default is None. In this case, all the words in the embedding_file will be taken into account and you can call WordEmbedding.get_word_index(embedding_file) to retrieve the dictionary. sequence_length: The length of a sequence. Positive int. Default is 500. encoder: The encoder for input sequences. String. 'cnn' or 'lstm' or 'gru' are supported. Default is 'cnn'. encoder_output_dim: The output dimension for the encoder. Positive int. Default is 256. """ def __init__(self, class_num, embedding_file, word_index=None, sequence_length=500, encoder="cnn", encoder_output_dim=256, **kwargs): if 'token_length' in kwargs: kwargs.pop('token_length') warnings.warn('The "token_length" argument in TextClassifier has been deprecated ' 'since 0.3.0, instead you should pass the arguments "embedding_file" ' 'and "word_index" to construct a TextClassifier with WordEmbedding ' 'as the first layer.') if 'bigdl_type' in kwargs: kwargs.pop('bigdl_type') self.bigdl_type = kwargs.get("bigdl_type") else: self.bigdl_type = "float" if kwargs: raise TypeError('Wrong arguments for TextClassifier: ' + str(kwargs)) self.class_num = class_num self.embedding = WordEmbedding(embedding_file, word_index, input_length=sequence_length) self.sequence_length = sequence_length self.encoder = encoder self.encoder_output_dim = encoder_output_dim self.model = self.build_model() super(TextClassifier, self).__init__(None, self.bigdl_type, int(class_num), self.embedding, int(sequence_length), encoder, int(encoder_output_dim), self.model)
[docs] def build_model(self): model = Sequential() model.add(self.embedding) if self.encoder.lower() == 'cnn': model.add(Convolution1D(self.encoder_output_dim, 5, activation='relu')) model.add(GlobalMaxPooling1D()) elif self.encoder.lower() == 'lstm': model.add(LSTM(self.encoder_output_dim)) elif self.encoder.lower() == 'gru': model.add(GRU(self.encoder_output_dim)) else: raise ValueError('Unsupported encoder for TextClassifier: ' + self.encoder) model.add(Dense(128)) model.add(Dropout(0.2)) model.add(Activation('relu')) model.add(Dense(self.class_num, activation='softmax')) return model
[docs] @staticmethod def load_model(path, weight_path=None, bigdl_type="float"): """ Load an existing TextClassifier model (with weights). # Arguments path: The path for the pre-defined model. Local file system, HDFS and Amazon S3 are supported. HDFS path should be like 'hdfs://[host]:[port]/xxx'. Amazon S3 path should be like 's3a://bucket/xxx'. weight_path: The path for pre-trained weights if any. Default is None. """ jmodel = callZooFunc(bigdl_type, "loadTextClassifier", path, weight_path) model = ZooModel._do_load(jmodel, bigdl_type) model.__class__ = TextClassifier return model
# For the following methods, please refer to KerasNet for documentation.
[docs] def compile(self, optimizer, loss, metrics=None): if isinstance(optimizer, six.string_types): optimizer = to_bigdl_optim_method(optimizer) if isinstance(loss, six.string_types): loss = to_bigdl_criterion(loss) if metrics and all(isinstance(metric, six.string_types) for metric in metrics): metrics = to_bigdl_metrics(metrics, loss) callZooFunc(self.bigdl_type, "textClassifierCompile", self.value, optimizer, loss, metrics)
[docs] def set_tensorboard(self, log_dir, app_name): callZooFunc(self.bigdl_type, "textClassifierSetTensorBoard", self.value, log_dir, app_name)
[docs] def set_checkpoint(self, path, over_write=True): callZooFunc(self.bigdl_type, "textClassifierSetCheckpoint", self.value, path, over_write)
[docs] def fit(self, x, batch_size=32, nb_epoch=10, validation_data=None): """ Fit on TextSet. """ assert isinstance(x, TextSet), "x should be a TextSet" if validation_data: assert isinstance(validation_data, TextSet), "validation_data should be a TextSet" callZooFunc(self.bigdl_type, "textClassifierFit", self.value, x, batch_size, nb_epoch, validation_data)
[docs] def evaluate(self, x, batch_size=32): """ Evaluate on TextSet. """ assert isinstance(x, TextSet), "x should be a TextSet" return callZooFunc(self.bigdl_type, "textClassifierEvaluate", self.value, x, batch_size)
[docs] def predict(self, x, batch_per_thread=4): """ Predict on TextSet. """ assert isinstance(x, TextSet), "x should be a TextSet" results = callZooFunc(self.bigdl_type, "textClassifierPredict", self.value, x, batch_per_thread) return TextSet(results)