zoo.models.textmatching package

Submodules

zoo.models.textmatching.knrm module

class zoo.models.textmatching.knrm.KNRM(text1_length, text2_length, embedding_file, word_index=None, train_embed=True, kernel_num=21, sigma=0.1, exact_sigma=0.001, target_mode='ranking', bigdl_type='float')

Bases: zoo.models.textmatching.text_matcher.TextMatcher

Kernel-pooling Neural Ranking Model with RBF kernel. https://arxiv.org/abs/1706.06613
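The kernel-pooling step from the paper can be sketched in plain Python. Each query word's row of cosine similarities against the document words goes through a set of RBF kernels, and each kernel's values are summed into a soft term frequency (soft-TF). The log of that sum is then accumulated over all query words, giving one feature per kernel. This assumes the kernel means are spread evenly over (-1, 1), as in MatchZoo's KNRM, with the last kernel moved to mu=1.0 and given `exact_sigma`. Treat that placement as an assumption about this class. `kernel_params` and `kernel_pooling` are hypothetical helper names.

```python
import math
from typing import List, Tuple


def kernel_params(kernel_num: int = 21, sigma: float = 0.1,
                  exact_sigma: float = 0.001) -> List[Tuple[float, float]]:
    """Return one (mu, sigma) pair per RBF kernel.

    Assumption: means are spread evenly over (-1, 1) as in MatchZoo's KNRM,
    and the kernel whose mean would exceed 1.0 becomes the exact-match kernel.
    """
    params = []
    for i in range(kernel_num):
        mu = 1.0 / (kernel_num - 1) + (2.0 * i) / (kernel_num - 1) - 1.0
        if mu > 1.0:
            # Exact-match kernel: a very narrow RBF centred on similarity 1.0.
            params.append((1.0, exact_sigma))
        else:
            params.append((mu, sigma))
    return params


def kernel_pooling(sim: List[List[float]],
                   params: List[Tuple[float, float]]) -> List[float]:
    """Per kernel: sum over query words of log(soft-TF over doc words)."""
    features = []
    for mu, s in params:
        total = 0.0
        for row in sim:  # one row of cosine similarities per query word
            soft_tf = sum(math.exp(-((m - mu) ** 2) / (2 * s ** 2)) for m in row)
            total += math.log(max(soft_tf, 1e-10))  # clamp avoids log(0)
        features.append(total)
    return features
```

For example, `kernel_pooling([[1.0, 0.2], [0.1, -0.3]], kernel_params())` returns 21 features for a 2-word query against a 2-word document. In the paper these features feed a final linear layer that produces the ranking score.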

# Arguments
text1_length: Sequence length of text1 (query).
text2_length: Sequence length of text2 (doc).
embedding_file: The path to the word embedding file.
Currently only the following GloVe files are supported: “glove.6B.50d.txt”, “glove.6B.100d.txt”, “glove.6B.200d.txt”, “glove.6B.300d.txt”, “glove.42B.300d.txt”, “glove.840B.300d.txt”. You can download them from https://nlp.stanford.edu/projects/glove/.
word_index: Dictionary mapping each word (string) to its index (int).
The index is supposed to start from 1, with 0 reserved for unknown words. During prediction, words that are not in the word_index used for training can be mapped to index 0. Default is None, in which case all the words in embedding_file are taken into account, and you can call WordEmbedding.get_word_index(embedding_file) to retrieve the dictionary.
train_embed: Boolean. Whether to train the embedding layer or not. Default is True.
kernel_num: Int > 1. The number of kernels to use. Default is 21.
sigma: Float. Defines the kernel width, i.e. the range of its soft-TF count. Default is 0.1.
exact_sigma: Float. The sigma used for the kernel that harvests exact matches, i.e. the RBF kernel with mu=1.0. Default is 0.001.
target_mode: String. The target mode of the model. Either ‘ranking’ or ‘classification’.
For ranking, the output is the relevance score between text1 and text2, and ‘rank_hinge’ is the recommended loss for pairwise training. For classification, the last layer is a sigmoid and the output is a probability between 0 and 1 indicating whether text1 is related to text2; ‘binary_crossentropy’ is the recommended loss for binary classification. Default is ‘ranking’.
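In ranking mode, a pairwise hinge loss of the kind the ‘rank_hinge’ recommendation refers to can be sketched as follows. The sketch makes two assumptions, both borrowed from MatchZoo's rank_hinge rather than confirmed for Analytics Zoo's loss. First, the scores in a batch alternate between a relevant document and its paired irrelevant document. Second, the margin is 1.0. The function name `rank_hinge` here is illustrative only.

```python
from typing import Sequence


def rank_hinge(scores: Sequence[float], margin: float = 1.0) -> float:
    """Mean pairwise hinge loss over interleaved (positive, negative) scores.

    Assumption: scores[0::2] belong to relevant documents and scores[1::2]
    to the irrelevant document paired with each of them.
    """
    if not scores or len(scores) % 2 != 0:
        raise ValueError("scores must contain (positive, negative) pairs")
    pos, neg = scores[0::2], scores[1::2]
    # Penalise each pair whose positive score does not beat the negative by `margin`.
    losses = [max(0.0, margin - p + n) for p, n in zip(pos, neg)]
    return sum(losses) / len(losses)
```

For example, `rank_hinge([2.0, 0.5, 0.5, 0.25])` gives 0.375. The first pair already clears the margin and contributes 0. The second pair contributes 1.0 - 0.5 + 0.25 = 0.75, and the mean over the two pairs is 0.375.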
build_model()
static load_model(path, weight_path=None, bigdl_type='float')

Load an existing KNRM model (with weights).

# Arguments
path: The path for the pre-defined model.
Local file system, HDFS and Amazon S3 are supported. An HDFS path should look like ‘hdfs://[host]:[port]/xxx’ and an Amazon S3 path like ‘s3a://bucket/xxx’.
weight_path: The path for pre-trained weights, if any. Default is None.
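The three supported path forms can be told apart by their URI scheme. The hypothetical helper below sketches that check with `urllib.parse.urlparse`, so a path can be validated before it is passed to `load_model`. It treats a path with no scheme as local. Whether `load_model` also accepts other spellings, such as `file://`, is not stated above, so the helper conservatively rejects them.

```python
from urllib.parse import urlparse


def path_scheme(path: str) -> str:
    """Classify a model path as 'local', 'hdfs' or 's3a' (hypothetical helper)."""
    scheme = urlparse(path).scheme
    if scheme == "":
        return "local"  # plain local file system path
    if scheme in ("hdfs", "s3a"):
        return scheme
    raise ValueError(f"unsupported scheme: {scheme!r}")
```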

zoo.models.textmatching.text_matcher module

class zoo.models.textmatching.text_matcher.TextMatcher(text1_length, vocab_size, embed_size=300, embed_weights=None, train_embed=True, target_mode='ranking', bigdl_type='float')

Bases: zoo.models.common.zoo_model.ZooModel, zoo.models.common.ranker.Ranker

The base class for text matching models in Analytics Zoo. The implementation follows MatchZoo: https://github.com/NTMC-Community/MatchZoo
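TextMatcher takes a `vocab_size`, an `embed_size` and optional `embed_weights`, while KNRM derives its embeddings from a GloVe file and a `word_index` whose indices start at 1. The sketch below shows one way those pieces could fit together. It reads GloVe-format text, where each line is a word followed by its space-separated vector components. The helper names, the all-zero row 0 for unknown words, and zero-padding to a fixed sequence length are assumptions for illustration, not Analytics Zoo's actual loading code.

```python
from typing import Dict, Iterable, List, Tuple


def load_glove(lines: Iterable[str]) -> Tuple[Dict[str, int], List[List[float]]]:
    """Build a word_index (starting at 1) and an embedding matrix from GloVe lines.

    Assumption: row 0 of the matrix is an all-zero vector for unknown words.
    """
    word_index: Dict[str, int] = {}
    weights: List[List[float]] = []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        parts = line.rstrip().split(" ")
        word, vector = parts[0], [float(x) for x in parts[1:]]
        if not weights:
            weights.append([0.0] * len(vector))  # row 0: unknown words
        word_index[word] = len(weights)
        weights.append(vector)
    return word_index, weights


def encode(tokens: List[str], word_index: Dict[str, int], length: int) -> List[int]:
    """Map tokens to indices (0 for unknown), then truncate or zero-pad to `length`."""
    ids = [word_index.get(t, 0) for t in tokens][:length]
    return ids + [0] * (length - len(ids))
```

Under these assumptions, `vocab_size` would be `len(weights)` (that is, `len(word_index) + 1`), `embed_size` would be `len(weights[0])`, and `encode(query_tokens, word_index, text1_length)` would yield a fixed-length index sequence for text1.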

Module contents