zoo.models.textclassification package
Submodules
zoo.models.textclassification.text_classifier module
class zoo.models.textclassification.text_classifier.TextClassifier(class_num, embedding_file, word_index=None, sequence_length=500, encoder='cnn', encoder_output_dim=256, **kwargs)
Bases: zoo.models.common.zoo_model.ZooModel
The model used for text classification, with WordEmbedding as its first layer.
# Arguments
class_num: The number of text categories to be classified. Positive int.
embedding_file: The path to the word embedding file. Currently only the following GloVe files are supported: 'glove.6B.50d.txt', 'glove.6B.100d.txt', 'glove.6B.200d.txt', 'glove.6B.300d.txt', 'glove.42B.300d.txt', 'glove.840B.300d.txt'. You can download them from https://nlp.stanford.edu/projects/glove/.
word_index: Dictionary mapping each word (string) to its corresponding index (int). Indices should start from 1, with 0 reserved for unknown words. During prediction, words that are not in the word_index used for training can be mapped to index 0. Default is None, in which case all the words in embedding_file are taken into account and you can call WordEmbedding.get_word_index(embedding_file) to retrieve the dictionary.
sequence_length: The length of a sequence. Positive int. Default is 500.
encoder: The encoder for input sequences. String. 'cnn', 'lstm' and 'gru' are supported. Default is 'cnn'.
encoder_output_dim: The output dimension of the encoder. Positive int. Default is 256.
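The word_index convention described above (indices start from 1, index 0 stands for unknown words) can be illustrated with a minimal plain-Python sketch. The helpers build_word_index and encode are hypothetical names and are not part of the zoo API. Padding short inputs with 0 and truncating long ones to sequence_length are assumptions made for illustration only, not documented behaviour of TextClassifier.

```python
from typing import Dict, Iterable, List

# Index 0 is reserved for unknown words, as stated in the word_index description.
UNKNOWN_INDEX = 0


def build_word_index(vocabulary: Iterable[str]) -> Dict[str, int]:
    """Hypothetical helper: give each distinct word an index starting from 1."""
    word_index: Dict[str, int] = {}
    for word in vocabulary:
        if word not in word_index:
            word_index[word] = len(word_index) + 1
    return word_index


def encode(tokens: List[str], word_index: Dict[str, int],
           sequence_length: int = 500) -> List[int]:
    """Hypothetical helper: map tokens to indices and fix the length.

    Words missing from word_index map to 0. Truncation and right-padding
    with 0 are assumptions of this sketch, not documented zoo behaviour.
    """
    indices = [word_index.get(token, UNKNOWN_INDEX) for token in tokens]
    indices = indices[:sequence_length]
    return indices + [UNKNOWN_INDEX] * (sequence_length - len(indices))


word_index = build_word_index(["the", "cat", "sat", "the"])
encoded = encode(["the", "dog", "sat"], word_index, sequence_length=5)
```

Here "dog" was never seen when the index was built, so it is encoded as 0, in line with the guidance for words that appear only at prediction time.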
static load_model(path, weight_path=None, bigdl_type='float')
Load an existing TextClassifier model (with weights).
# Arguments
path: The path for the pre-defined model. Local file system, HDFS and Amazon S3 are supported. HDFS path should be like 'hdfs://[host]:[port]/xxx'. Amazon S3 path should be like 's3a://bucket/xxx'.
weight_path: The path for pre-trained weights if any. Default is None.