zoo.models.textclassification package

Submodules

zoo.models.textclassification.text_classifier module

class zoo.models.textclassification.text_classifier.TextClassifier(class_num, embedding_file, word_index=None, sequence_length=500, encoder='cnn', encoder_output_dim=256, **kwargs)

Bases: zoo.models.common.zoo_model.ZooModel

The model used for text classification with WordEmbedding as its first layer.

# Arguments

class_num: The number of text categories to be classified. Positive int.

embedding_file: The path to the word embedding file. Currently only the following GloVe files are supported: “glove.6B.50d.txt”, “glove.6B.100d.txt”, “glove.6B.200d.txt”, “glove.6B.300d.txt”, “glove.42B.300d.txt”, “glove.840B.300d.txt”. You can download them from: https://nlp.stanford.edu/projects/glove/.

word_index: Dictionary mapping each word (string) to its corresponding index (int). Indices should start from 1, with 0 reserved for unknown words. During prediction, any word that is not in the word_index used for training can be mapped to index 0. Default is None, in which case all the words in the embedding_file are taken into account and you can call WordEmbedding.get_word_index(embedding_file) to retrieve the dictionary.

sequence_length: The length of a sequence. Positive int. Default is 500.

encoder: The encoder for input sequences. String. ‘cnn’, ‘lstm’ and ‘gru’ are supported. Default is ‘cnn’.

encoder_output_dim: The output dimension for the encoder. Positive int. Default is 256.
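
The word_index convention described above (indices start from 1, 0 is reserved for unknown words) can be illustrated with a minimal pure-Python sketch. The helpers build_word_index and encode are hypothetical names for illustration only and are not part of zoo; padding short sequences with 0 is likewise an assumption of this sketch.

```python
from collections import Counter

UNKNOWN_INDEX = 0  # index 0 is reserved for unknown words, as described above


def build_word_index(texts, max_words=None):
    """Hypothetical helper: map each word to an int index starting at 1."""
    counts = Counter(word for text in texts for word in text.lower().split())
    # Sort by descending frequency, then alphabetically, so the result is deterministic.
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    if max_words is not None:
        ordered = ordered[:max_words]
    return {word: i for i, (word, _) in enumerate(ordered, start=1)}


def encode(text, word_index, sequence_length=500):
    """Hypothetical helper: turn text into a fixed-length list of indices."""
    # Words missing from word_index are mapped to index 0.
    ids = [word_index.get(w, UNKNOWN_INDEX) for w in text.lower().split()]
    ids = ids[:sequence_length]  # truncate long inputs
    # Assumption of this sketch: short inputs are right-padded with 0.
    return ids + [UNKNOWN_INDEX] * (sequence_length - len(ids))


word_index = build_word_index(["the cat sat", "the dog sat down"])
encoded = encode("the bird sat", word_index, sequence_length=5)
```

A dictionary built this way has the shape that the word_index argument expects: string keys, int values, and no entry using index 0.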

build_model()
compile(optimizer, loss, metrics=None)
evaluate(x, batch_size=32)

Evaluate on TextSet.

fit(x, batch_size=32, nb_epoch=10, validation_data=None)

Fit on TextSet.

static load_model(path, weight_path=None, bigdl_type='float')

Load an existing TextClassifier model (with weights).

# Arguments

path: The path for the pre-defined model. Local file system, HDFS and Amazon S3 are supported. An HDFS path should look like ‘hdfs://[host]:[port]/xxx’. An Amazon S3 path should look like ‘s3a://bucket/xxx’.

weight_path: The path for pre-trained weights, if any. Default is None.
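
As a rough illustration of the three path forms listed above, the following sketch classifies a path by its prefix. The function storage_kind and the example paths are hypothetical and not part of zoo; it only mirrors the prefixes named in the description and does not reflect how load_model resolves paths internally.

```python
def storage_kind(path):
    """Hypothetical helper: guess the storage backend from the path prefix."""
    if path.startswith("hdfs://"):
        return "hdfs"
    if path.startswith("s3a://"):
        return "s3"
    # Anything without one of the prefixes above is treated as a local path.
    return "local"


# Example paths are made up for illustration.
kinds = [
    storage_kind("/tmp/text_classifier.model"),
    storage_kind("hdfs://namenode:9000/models/text_classifier.model"),
    storage_kind("s3a://bucket/models/text_classifier.model"),
]
```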

predict(x, batch_per_thread=4)

Predict on TextSet.

set_checkpoint(path, over_write=True)
set_tensorboard(log_dir, app_name)

Module contents