com.aliasi.classify
Class TradNaiveBayesClassifier

java.lang.Object
  extended by com.aliasi.classify.TradNaiveBayesClassifier
All Implemented Interfaces:
BaseClassifier<CharSequence>, Classifier<CharSequence,JointClassification>, ConditionalClassifier<CharSequence>, JointClassifier<CharSequence>, RankedClassifier<CharSequence>, ScoredClassifier<CharSequence>, ClassificationHandler<CharSequence,Classification>, Handler, ObjectHandler<Classified<CharSequence>>, Compilable, Serializable

public class TradNaiveBayesClassifier
extends Object
implements ClassificationHandler<CharSequence,Classification>, Classifier<CharSequence,JointClassification>, JointClassifier<CharSequence>, ObjectHandler<Classified<CharSequence>>, Serializable, Compilable

A TradNaiveBayesClassifier implements a traditional token-based approach to naive Bayes text classification. It wraps a tokenization factory to convert character sequences into sequences of tokens. This implementation supports several enhancements to simple naive Bayes: priors, length normalization, and semi-supervised training with EM.

It is the bag of token counts (the "bag of words") that is actually classified, not the raw character sequence, so any character sequences that produce the same bag of tokens are considered equal.

Naive Bayes is trainable online, meaning that it can be given training instances one at a time, and at any point can be used as a classifier. Training cases consist of a character sequence and classification, as dictated by the interface ObjectHandler<Classified<CharSequence>>.
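For instance, a classifier might be constructed and trained one case at a time as in the sketch below. The categories and training texts are invented for illustration, and IndoEuropeanTokenizerFactory stands in for whatever tokenizer factory an application would use:

 import com.aliasi.classify.Classification;
 import com.aliasi.classify.Classified;
 import com.aliasi.classify.JointClassification;
 import com.aliasi.classify.TradNaiveBayesClassifier;
 import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;

 import java.util.HashSet;
 import java.util.Set;

 public class OnlineTrainingDemo {
     public static void main(String[] args) {
         Set<String> cats = new HashSet<String>();
         cats.add("sports");
         cats.add("politics");

         TradNaiveBayesClassifier classifier
             = new TradNaiveBayesClassifier(cats, IndoEuropeanTokenizerFactory.INSTANCE);

         // train online, one classified case at a time
         classifier.handle(new Classified<CharSequence>("the match went to extra innings",
                                                        new Classification("sports")));
         classifier.handle(new Classified<CharSequence>("the senate passed the bill",
                                                        new Classification("politics")));

         // usable as a classifier at any point during training
         JointClassification c = classifier.classify("the final innings of the match");
         System.out.println("best category=" + c.bestCategory());
     }
 }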

Given a character sequence, a naive Bayes classifier returns joint probability estimates of categories and tokens; this is reflected in its implementing the Classifier<CharSequence,JointClassification> interface. Note that this is the joint probability of the token counts, so sums of probabilities over all input character sequences will exceed 1.0. Typically, only the conditional probability estimates are used in practice.

If there is length normalization, the joint probabilities will not sum to 1.0 over all inputs and outputs. The conditional probabilities will always sum to 1.0.

Classification

A token-based naive Bayes classifier computes joint token count and category probabilities by factoring the joint probability into the marginal probability of a category times the conditional probability of the tokens given the category.
 p(tokens,cat) = p(tokens|cat) * p(cat)
Conditional probabilities are derived by applying Bayes's rule to invert the probability calculation:
 p(cat|tokens) = p(tokens,cat) / p(tokens)
              = p(tokens|cat) * p(cat) / p(tokens)
The tokens are assumed to be independent (this is the "naive" step):
 p(tokens|cat) = p(tokens[0]|cat) * ... * p(tokens[tokens.length-1]|cat)
              = Π{i < tokens.length} p(tokens[i]|cat)
Finally, an explicit marginalization allows us to compute the marginal distribution of tokens:
 p(tokens) = Σcat' p(tokens,cat')
          = Σcat' p(tokens|cat') * p(cat')
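Putting the pieces together, the following sketch computes p(cat|tokens) for a toy model; all probabilities are invented, and the code follows the chain of equations above exactly:

 // Toy computation of p(cat|tokens); the model probabilities are invented.
 public class PosteriorDemo {
     public static void main(String[] args) {
         double[] pCat = { 0.7, 0.3 };                 // p(cat) for cats 0,1
         double[][] pTokGivenCat = { { 0.9, 0.1 },     // p(token|cat 0)
                                     { 0.2, 0.8 } };   // p(token|cat 1)
         int[] tokens = { 0, 0, 1 };                   // observed token sequence

         // p(tokens,cat) = p(cat) * Π p(tokens[i]|cat)  (naive independence)
         double[] joint = new double[pCat.length];
         double marginal = 0.0;                        // p(tokens) = Σ p(tokens,cat')
         for (int c = 0; c < pCat.length; ++c) {
             joint[c] = pCat[c];
             for (int tok : tokens)
                 joint[c] *= pTokGivenCat[c][tok];
             marginal += joint[c];
         }

         // p(cat|tokens) = p(tokens,cat) / p(tokens)   (Bayes's rule)
         for (int c = 0; c < pCat.length; ++c)
             System.out.println("p(cat " + c + "|tokens) = " + joint[c] / marginal);
     }
 }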

Estimation with Priors

We now have defined the conditional probability p(cat|tokens) in terms of two distributions, the conditional probability of a token given a category p(token|cat), and the marginal probability of a category p(cat) (sometimes called the category's prior probability, though this shouldn't be confused with the usual Bayesian prior on model parameters).

Traditional naive Bayes uses maximum a posteriori (MAP) estimates of the multinomial distributions: p(cat) over the set of categories, and, for each category cat, the multinomial distribution p(token|cat) over the set of tokens. Traditional naive Bayes employs the Dirichlet conjugate prior for multinomials, which is straightforward to compute: a fixed "prior count" is added to each count in the training data, hence the traditional name "additive smoothing".

Two sets of counts are sufficient for estimating a traditional naive Bayes classifier. The first is tokenCount(w,c), the number of times token w appeared as a token in a training case for category c. The second is caseCount(c), which is the number of training cases for category c.

We assume prior counts α for the case counts and β for the token counts. These values are supplied in the constructor for this class. The estimates for category and token probabilities p' are most easily understood as proportions:

 p'(w|c) ∝ tokenCount(w,c) + β

   p'(c) ∝ caseCount(c) + α
The probability estimates p' are obtained through the usual normalization:
 p'(w|c) = ( tokenCount(w,c) + β ) / Σw' ( tokenCount(w',c) + β )

   p'(c) = ( caseCount(c) + α ) / Σc' ( caseCount(c') + α )
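As a concrete illustration, the following sketch computes the smoothed estimates from made-up counts; the α and β values match the prior-count parameters supplied to the constructor:

 // Toy MAP estimates from invented counts, with prior counts alpha and beta.
 public class MapEstimateDemo {
     public static void main(String[] args) {
         double alpha = 0.5;                        // prior count for categories
         double beta = 0.5;                         // prior count for tokens

         double[] caseCount = { 3.0, 1.0 };         // training cases per category
         double[][] tokenCount = { { 5.0, 1.0 },    // token counts for cat 0
                                   { 0.0, 4.0 } };  // token counts for cat 1

         // p'(c) = (caseCount(c) + alpha) / Σc' (caseCount(c') + alpha)
         double catTotal = 0.0;
         for (double n : caseCount)
             catTotal += n + alpha;
         for (int c = 0; c < caseCount.length; ++c)
             System.out.println("p'(" + c + ") = " + (caseCount[c] + alpha) / catTotal);

         // p'(w|c) = (tokenCount(w,c) + beta) / Σw' (tokenCount(w',c) + beta)
         for (int c = 0; c < tokenCount.length; ++c) {
             double tokTotal = 0.0;
             for (double n : tokenCount[c])
                 tokTotal += n + beta;
             for (int w = 0; w < tokenCount[c].length; ++w)
                 System.out.println("p'(" + w + "|" + c + ") = "
                                    + (tokenCount[c][w] + beta) / tokTotal);
         }
     }
 }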

Maximum Likelihood Estimates

Although not traditionally used for naive Bayes, maximum likelihood estimates arise from setting the prior counts equal to zero (α = β = 0). The prior counts drop out of the equations to yield the maximum likelihood estimates p*:

 p*(w|c) = tokenCount(w,c) / Σw' tokenCount(w',c)

   p*(c) = caseCount(c) / Σc' caseCount(c')

Weighted and Conditional Training

Unlike traditional naive Bayes implementations, this class allows weighted training, including training directly from a conditional classification. When training using a conditional classification, each category is weighted according to its conditional probability.

Weights may be negative, allowing counts to be decremented (e.g. for Gibbs sampling).
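Continuing the earlier sketch (with classifier as constructed above), weighted training and decrementing might look as follows; the text and weights are invented:

 // Hypothetical weighted training using train(CharSequence,Classification,double).
 Classification sports = new Classification("sports");

 // count this case as half a training instance
 classifier.train("extra innings at the stadium", sports, 0.5);

 // a negative count decrements, e.g. to undo the update above
 classifier.train("extra innings at the stadium", sports, -0.5);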

Length Normalization

Because of the (almost always faulty) assumption that tokens are independent, the conditional probability estimates of a naive Bayes classifier tend toward either 0.0 or 1.0 as the input grows longer. In practice, it sometimes helps to length normalize the documents; that is, to treat each document as if it were a fixed number of tokens long, lengthNorm.

Length normalization can be computed directly on the linear scale:

 pn(tokens|cat) = p(tokens|cat)^(lengthNorm/tokens.length)
 
but is more easily understood on the log scale, where we multiply the length norm by the log probability normalized per token:
 log2 pn(tokens|cat) = lengthNorm * log2 p(tokens|cat) / tokens.length
 
The length normalization parameter is supplied in the constructor, with a Double.NaN value indicating that no length normalization should be done.

Length normalization is also applied during training. The length normalization may be changed using the setLengthNorm(double) method; this allows, for instance, training to skip length normalization and classification to use it.
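For instance, a hypothetical usage that trains without length normalization and classifies with it:

 // Train without length normalization, classify with a norm of 10 tokens.
 classifier.setLengthNorm(Double.NaN);   // turn off normalization for training
 // ... training calls ...
 classifier.setLengthNorm(10.0);         // treat every input as 10 tokens long
 JointClassification c = classifier.classify("text to classify");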

Semi-Supervised Training with Expectation Maximization (EM)

Naive Bayes is a common model to use with the general semi-supervised or unsupervised training strategy known as expectation maximization (EM). The basic idea behind EM is to start with a classifier, apply it to unseen data, and then use the resulting weighted output predictions as training data.

EM proceeds by epoch. Each epoch consists of an expectation (E) step followed by a maximization (M) step; the expectation step computes expectations that are then fed in as training weights to the maximization step.

The version of EM implemented in this class allows a mixture of supervised and unsupervised data.

The supervised training data is in the form of a corpus of classified character sequences, implementing Corpus<ObjectHandler<Classified<CharSequence>>>.

Unsupervised data is in the form of a corpus of texts, implementing Corpus<ObjectHandler<CharSequence>>.

The method also requires a factory with which to produce a new classifier in each epoch, namely an implementation of Factory<TradNaiveBayesClassifier>. It also takes an initial classifier, which may be configured differently from the classifiers generated by the factory.

EM works by iteratively training better and better classifiers using the previous classifier to label unlabeled data to use for training.

 set lastClassifier to initialClassifier
 for (epoch = 0; epoch < maxEpochs; ++epoch) {
      create classifier using factory
      train classifier on supervised items
      for (x in unsupervised items) {
          compute p(c|x) with lastClassifier
      for (c in categories)
              train classifier on c weighted by p(c|x)
      }
      evaluate corpus and model probability under classifier
      set lastClassifier to classifier
      break if converged
 }
 return lastClassifier

Note that in each round, the new classifier is trained on the supervised items.
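A complete invocation of emTrain(TradNaiveBayesClassifier,Factory,Corpus,Corpus,double,int,double,Reporter) might look like the following sketch. The corpora each supply a single invented example, the parameter values are illustrative, and it is assumed the EM methods consult only the corpora's visitTrain methods:

 import com.aliasi.classify.Classification;
 import com.aliasi.classify.Classified;
 import com.aliasi.classify.TradNaiveBayesClassifier;
 import com.aliasi.corpus.Corpus;
 import com.aliasi.corpus.ObjectHandler;
 import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
 import com.aliasi.tokenizer.TokenizerFactory;
 import com.aliasi.util.Factory;

 import java.io.IOException;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Set;

 public class EmTrainDemo {
     public static void main(String[] args) throws IOException {
         final Set<String> cats
             = new HashSet<String>(Arrays.asList("sports", "politics"));
         final TokenizerFactory tf = IndoEuropeanTokenizerFactory.INSTANCE;

         // heavier smoothing for the initial classifier (see the note below)
         TradNaiveBayesClassifier initialClassifier
             = new TradNaiveBayesClassifier(cats, tf, 4.0, 4.0, Double.NaN);

         Factory<TradNaiveBayesClassifier> factory
             = new Factory<TradNaiveBayesClassifier>() {
                 public TradNaiveBayesClassifier create() {
                     return new TradNaiveBayesClassifier(cats, tf, 0.5, 0.5, Double.NaN);
                 }
             };

         // one labeled and one unlabeled example; real corpora would be larger
         Corpus<ObjectHandler<Classified<CharSequence>>> labeledData
             = new Corpus<ObjectHandler<Classified<CharSequence>>>() {
                 public void visitTrain(ObjectHandler<Classified<CharSequence>> h) {
                     h.handle(new Classified<CharSequence>(
                                  "the home team won in extra innings",
                                  new Classification("sports")));
                 }
             };
         Corpus<ObjectHandler<CharSequence>> unlabeledData
             = new Corpus<ObjectHandler<CharSequence>>() {
                 public void visitTrain(ObjectHandler<CharSequence> h) {
                     h.handle("the election results were close");
                 }
             };

         TradNaiveBayesClassifier trained
             = TradNaiveBayesClassifier.emTrain(initialClassifier, factory,
                                                labeledData, unlabeledData,
                                                1.0,    // minTokenCount
                                                20,     // maxEpochs
                                                0.0001, // minImprovement
                                                null);  // null reporter: no reporting
         System.out.println(trained);
     }
 }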

In general, we have found that EM training works best if the initial classifier does more smoothing than the classifiers returned by the factory.

Annealing, of a sort, may be built in by having the factory return a sequence of classifiers with ever longer length normalizations and/or lower prior counts, both of which sharpen the posterior predictions of a naive Bayes classifier. With a short length normalization, classifications are driven closer to uniform; with longer length normalizations, they become more peaked.
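A hedged sketch of such an annealing factory, continuing the example above (the schedule constants are invented):

 // Each create() call returns a classifier with lower prior counts and a
 // longer length norm than the last, sharpening the posteriors over epochs.
 Factory<TradNaiveBayesClassifier> annealingFactory
     = new Factory<TradNaiveBayesClassifier>() {
         private int mEpoch = 0;
         public TradNaiveBayesClassifier create() {
             ++mEpoch;
             double prior = Math.max(0.05, 2.0 / mEpoch);       // ever lower priors
             double lengthNorm = Math.min(40.0, 5.0 * mEpoch);  // ever longer norms
             return new TradNaiveBayesClassifier(cats, tf, prior, prior, lengthNorm);
         }
     };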

Unsupervised Learning and EM Soft Clustering

It is possible to train a classifier in a completely unsupervised fashion by having the initial classifier assign categories at random. Only the number of categories must be fixed. The algorithm is exactly the same, and the result after convergence or the maximum number of epochs is a classifier.

Now take the trained classifier and run it over the texts in the unsupervised text corpus. This assigns each text a probability of belonging to each of the categories, which is known as a soft clustering; the overall algorithm is known as EM clustering. If each item is then assigned to its most likely category, the result is a hard clustering.
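A sketch of the random initialization, with invented cluster names and assuming a list unlabeledTexts of texts is available:

 // Fully unsupervised initialization: fix the number of clusters, then train
 // the initial classifier on randomly assigned categories.
 final String[] clusterNames = { "cluster0", "cluster1", "cluster2" };
 Set<String> clusters = new HashSet<String>(Arrays.asList(clusterNames));
 TradNaiveBayesClassifier randomInit
     = new TradNaiveBayesClassifier(clusters, tf, 2.0, 2.0, Double.NaN);
 Random random = new Random(42L);
 for (CharSequence text : unlabeledTexts)   // unlabeledTexts: List<CharSequence>
     randomInit.handle(new Classified<CharSequence>(
         text,
         new Classification(clusterNames[random.nextInt(clusterNames.length)])));
 // then run emTrain/emIterator as above with an empty labeled corpus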

Serialization and Compilation

A naive Bayes classifier may be serialized. The object read back in will behave just as the naive Bayes classifier that was serialized. The tokenizer factory must be serializable in order to serialize the classifier.

A naive Bayes classifier may be compiled. The object read back in will provide a joint classifier implementing Classifier<CharSequence,JointClassification>. The compiled version precomputes the log probabilities to speed run-time computations. A compiled classifier may not be trained. In order to be compiled, the tokenizer factory must be either serializable or compilable.
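For instance, a compile round trip might use the utility method AbstractExternalizable.compile(Compilable) from com.aliasi.util, which writes the compiled form to bytes and reads it back in one call; the cast reflects the compiled interface documented above:

 // A minimal sketch of a compile round trip.
 static Classifier<CharSequence,JointClassification>
         compile(TradNaiveBayesClassifier classifier)
         throws IOException, ClassNotFoundException {
     @SuppressWarnings("unchecked")
     Classifier<CharSequence,JointClassification> compiled
         = (Classifier<CharSequence,JointClassification>)
           AbstractExternalizable.compile(classifier);
     return compiled;   // fast, thread safe, but no longer trainable
 }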

Comparison to NaiveBayesClassifier

The naive Bayes classifier implemented in NaiveBayesClassifier differs from this version in smoothing the token estimates with character language model estimates.

Thread Safety

A TradNaiveBayesClassifier must be synchronized externally using read/write synchronization (e.g. with a ReadWriteLock). The write methods are handle(Classified), train(CharSequence,Classification,double), trainConditional(CharSequence,ConditionalClassification,double,double), and setLengthNorm(double); all other methods are read methods.

A compiled classifier is completely thread safe.
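A minimal sketch of such external synchronization using java.util.concurrent.locks.ReentrantReadWriteLock; the classifier and training case are assumed to be defined elsewhere:

 final ReadWriteLock lock = new ReentrantReadWriteLock();

 // write methods (training, setLengthNorm) take the write lock
 lock.writeLock().lock();
 try {
     classifier.handle(classifiedCase);   // classifiedCase assumed defined
 } finally {
     lock.writeLock().unlock();
 }

 // all other methods, such as classify, take the read lock
 lock.readLock().lock();
 try {
     JointClassification c = classifier.classify("input text");
 } finally {
     lock.readLock().unlock();
 }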

Since:
LingPipe 3.8
Version:
3.9.1
Author:
Bob Carpenter
See Also:
Serialized Form

Constructor Summary
TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory)
          Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory.
TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory, double categoryPrior, double tokenInCategoryPrior, double lengthNorm)
          Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory, priors and length normalization.
 
Method Summary
 Set<String> categorySet()
          Returns a set of categories for this classifier.
 JointClassification classify(CharSequence in)
          Return the classification of the specified character sequence.
 void compileTo(ObjectOutput out)
          Compile this classifier to the specified object output.
static Iterator<TradNaiveBayesClassifier> em(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ClassificationHandler<CharSequence,Classification>> labeledData, Corpus<TextHandler> unlabeledData, double minTokenCount)
          Deprecated. Use emIterator(TradNaiveBayesClassifier,Factory,Corpus,Corpus,double) instead.
static TradNaiveBayesClassifier em(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ClassificationHandler<CharSequence,Classification>> labeledData, Corpus<TextHandler> unlabeledData, double minTokenCount, int maxEpochs, double minImprovement, Reporter reporter)
          Deprecated. Use emTrain(TradNaiveBayesClassifier,Factory,Corpus,Corpus,double,int,double,Reporter) instead.
static Iterator<TradNaiveBayesClassifier> emIterator(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount)
          Apply the expectation maximization (EM) algorithm to train a traditional naive Bayes classifier using the specified labeled and unlabeled data, initial classifier, and factory for creating subsequent classifiers.
static TradNaiveBayesClassifier emTrain(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount, int maxEpochs, double minImprovement, Reporter reporter)
          Apply the expectation maximization (EM) algorithm to train a traditional naive Bayes classifier using the specified labeled and unlabeled data, initial classifier, factory for creating subsequent classifiers, maximum number of epochs, minimum improvement per epoch, and reporter to which progress reports are sent.
 void handle(CharSequence cSeq, Classification classification)
          Deprecated. Use handle(Classified) instead.
 void handle(Classified<CharSequence> classifiedObject)
          Trains the classifier with the specified classified character sequence.
 boolean isKnownToken(String token)
          Returns true if the token has been seen in training data.
 Set<String> knownTokenSet()
          Returns an unmodifiable view of the set of tokens.
 double lengthNorm()
          Returns the length normalization factor for this classifier.
 double log2CaseProb(CharSequence input)
          Returns the log (base 2) marginal probability of the specified input.
 double log2ModelProb()
          Returns the log (base 2) of the probability density of this model in the Dirichlet prior specified by this classifier.
 double probCat(String cat)
          Returns the probability estimate for the specified category.
 double probToken(String token, String cat)
          Returns the probability of the specified token in the specified category.
 void setLengthNorm(double lengthNorm)
          Set the length normalization factor to the specified value.
 String toString()
          Return a string representation of this classifier.
 void train(CharSequence cSeq, Classification classification, double count)
          Trains the classifier with the specified case consisting of a character sequence and conditional classification with the specified count.
 void trainConditional(CharSequence cSeq, ConditionalClassification classification, double countMultiplier, double minCount)
          Trains this classifier using tokens extracted from the specified character sequence, using category count multipliers derived by multiplying the specified count multiplier by the conditional probability of a category in the specified classification.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Constructor Detail

TradNaiveBayesClassifier

public TradNaiveBayesClassifier(Set<String> categorySet,
                                TokenizerFactory tokenizerFactory)
Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory. The category and token-in-category priors are set to a reasonable default value of 0.5, and there is no length normalization (the length normalization is set to Double.NaN).

See the class documentation above for more information.

Parameters:
categorySet - Categories for classification.
tokenizerFactory - Factory to convert char sequences to tokens.
Throws:
IllegalArgumentException - If there are fewer than two categories.

TradNaiveBayesClassifier

public TradNaiveBayesClassifier(Set<String> categorySet,
                                TokenizerFactory tokenizerFactory,
                                double categoryPrior,
                                double tokenInCategoryPrior,
                                double lengthNorm)
Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory, priors and length normalization. See the class documentation above for an explanation of the parameters' effect on classification.

Parameters:
categorySet - Categories for classification.
tokenizerFactory - Factory to convert char sequences to tokens.
categoryPrior - Prior count for categories.
tokenInCategoryPrior - Prior count for tokens per category.
lengthNorm - A positive, finite length norm, or Double.NaN if no length normalization is to be done.
Throws:
IllegalArgumentException - If either prior is negative or not finite, if there are fewer than two categories, or if the length normalization constant is negative, zero, or infinite.
Method Detail

toString

public String toString()
Return a string representation of this classifier.

Overrides:
toString in class Object
Returns:
String representation of this classifier.

categorySet

public Set<String> categorySet()
Returns a set of categories for this classifier.

Returns:
The set of categories for this classifier.

setLengthNorm

public void setLengthNorm(double lengthNorm)
Set the length normalization factor to the specified value. See the class documentation above for details.

Parameters:
lengthNorm - Length normalization or Double.NaN to turn off normalization.
Throws:
IllegalArgumentException - If the length norm is infinite, zero, or negative.

classify

public JointClassification classify(CharSequence in)
Return the classification of the specified character sequence.

Specified by:
classify in interface BaseClassifier<CharSequence>
Specified by:
classify in interface Classifier<CharSequence,JointClassification>
Specified by:
classify in interface ConditionalClassifier<CharSequence>
Specified by:
classify in interface JointClassifier<CharSequence>
Specified by:
classify in interface RankedClassifier<CharSequence>
Specified by:
classify in interface ScoredClassifier<CharSequence>
Parameters:
in - Character sequence being classified.
Returns:
The classification of the char sequence.

lengthNorm

public double lengthNorm()
Returns the length normalization factor for this classifier. See the class documentation above for details.

Returns:
The length normalization for this classifier.

isKnownToken

public boolean isKnownToken(String token)
Returns true if the token has been seen in training data.

Parameters:
token - Token to test.
Returns:
true if the token has been seen in training data.

knownTokenSet

public Set<String> knownTokenSet()
Returns an unmodifiable view of the set of tokens. The set is not modifiable, but will change to reflect any tokens added during training.

Returns:
The set of known tokens.

probToken

public double probToken(String token,
                        String cat)
Returns the probability of the specified token in the specified category. See the class documentation above for definitions.

Throws:
IllegalArgumentException - If the category is not known or the token is not known.

compileTo

public void compileTo(ObjectOutput out)
               throws IOException
Compile this classifier to the specified object output. See the class documentation above for details.

Specified by:
compileTo in interface Compilable
Parameters:
out - Object output to which this classifier is compiled.
Throws:
IOException - If there is an underlying I/O error during the write.

probCat

public double probCat(String cat)
Returns the probability estimate for the specified category.

Parameters:
cat - Category whose probability is returned.
Returns:
Probability for category.
Throws:
IllegalArgumentException - If the category is not known.

handle

public void handle(Classified<CharSequence> classifiedObject)
Trains the classifier with the specified classified character sequence. Only the first-best result is used from the classification; to train on conditional outputs, see trainConditional(CharSequence,ConditionalClassification,double,double).

Specified by:
handle in interface ObjectHandler<Classified<CharSequence>>
Parameters:
classifiedObject - Classified character sequence.

handle

@Deprecated
public void handle(CharSequence cSeq,
                   Classification classification)
Deprecated. Use handle(Classified) instead.

Trains the classifier with the specified case consisting of a character sequence and first-best classification. Only the first-best result is used from the classification; to train on conditional outputs, see trainConditional(CharSequence,ConditionalClassification,double,double).

Specified by:
handle in interface ClassificationHandler<CharSequence,Classification>
Parameters:
cSeq - Character sequence being classified.
classification - Classification of character sequence.

trainConditional

public void trainConditional(CharSequence cSeq,
                             ConditionalClassification classification,
                             double countMultiplier,
                             double minCount)
Trains this classifier using tokens extracted from the specified character sequence, using category count multipliers derived by multiplying the specified count multiplier by the conditional probability of a category in the specified classification. A category is not trained for the sequence if its conditional probability times the count multiplier is less than the minimum count.

Parameters:
cSeq - Character sequence being trained.
classification - Conditional classification to train.
countMultiplier - Count multiplier of training instance.
minCount - Minimum count for which a category is trained for this character sequence.
Throws:
IllegalArgumentException - If the count multiplier is negative or not finite, or if the minimum count is negative or not a number.

train

public void train(CharSequence cSeq,
                  Classification classification,
                  double count)
Trains the classifier with the specified case consisting of a character sequence and conditional classification with the specified count.

If the count value is negative, counts are subtracted rather than added. If any of the counts fall below zero, an illegal argument exception will be thrown and the classifier will be reverted to the counts in place before the method was called. Cleanup after errors requires the tokenizer factory to return the same tokenizer given the same string, but no check is made that it does.

Parameters:
cSeq - Character sequence on which to train.
classification - Classification to train with character sequence.
count - How many instances the classification will count as for training purposes.
Throws:
IllegalArgumentException - If the count is negative and the resulting decrements cause any accumulated count to fall below zero.

log2CaseProb

public double log2CaseProb(CharSequence input)
Returns the log (base 2) marginal probability of the specified input. This value is calculated by:
p(x) = Σc in cats p(c,x)
Note that this value is normalized by the number of tokens in the input, so that the probabilities of all inputs x of a fixed length n sum to 1.0:
Σ{x : length(x) = n} p(x) = 1.0

Parameters:
input - Input character sequence.
Returns:
The log probability of the input under this joint model.

log2ModelProb

public double log2ModelProb()
Returns the log (base 2) of the probability density of this model in the Dirichlet prior specified by this classifier. Note that the result is a log density, which is not technically a probability, and may be positive.

The result is the sum of the log density of the multinomial over categories and the log density of the per-category multinomials over tokens.

For a definition of the probability function for each category's multinomial and the overall category multinomial, see Statistics.dirichletLog2Prob(double,double[]).

Returns:
The log model density value.

em

@Deprecated
public static Iterator<TradNaiveBayesClassifier> em(TradNaiveBayesClassifier initialClassifier,
                                                    Factory<TradNaiveBayesClassifier> classifierFactory,
                                                    Corpus<ClassificationHandler<CharSequence,Classification>> labeledData,
                                                    Corpus<TextHandler> unlabeledData,
                                                    double minTokenCount)
                                             throws IOException
Deprecated. Use emIterator(TradNaiveBayesClassifier,Factory,Corpus,Corpus,double) instead.

Apply the expectation maximization (EM) algorithm to train a traditional naive Bayes classifier using the specified labeled and unlabeled data, initial classifier, and factory for creating subsequent classifiers.

This method lets the client take control over assessing convergence, so there are no convergence-related arguments.

Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
Returns:
An iterator over classifiers that returns each epoch's classifier.
Throws:
IOException

emIterator

public static Iterator<TradNaiveBayesClassifier> emIterator(TradNaiveBayesClassifier initialClassifier,
                                                            Factory<TradNaiveBayesClassifier> classifierFactory,
                                                            Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
                                                            Corpus<ObjectHandler<CharSequence>> unlabeledData,
                                                            double minTokenCount)
                                                     throws IOException
Apply the expectation maximization (EM) algorithm to train a traditional naive Bayes classifier using the specified labeled and unlabeled data, initial classifier, and factory for creating subsequent classifiers.

This method lets the client take control over assessing convergence, so there are no convergence-related arguments.

Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
Returns:
An iterator over classifiers that returns each epoch's classifier.
Throws:
IOException

em

@Deprecated
public static TradNaiveBayesClassifier em(TradNaiveBayesClassifier initialClassifier,
                                          Factory<TradNaiveBayesClassifier> classifierFactory,
                                          Corpus<ClassificationHandler<CharSequence,Classification>> labeledData,
                                          Corpus<TextHandler> unlabeledData,
                                          double minTokenCount,
                                          int maxEpochs,
                                          double minImprovement,
                                          Reporter reporter)
                                   throws IOException
Deprecated. Use emTrain(TradNaiveBayesClassifier,Factory,Corpus,Corpus,double,int,double,Reporter) instead.

Apply the expectation maximization (EM) algorithm to train a traditional naive Bayes classifier using the specified labeled and unlabeled data, initial classifier, factory for creating subsequent classifiers, maximum number of epochs, minimum improvement per epoch, and reporter to which progress reports are sent.

Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
maxEpochs - Maximum number of epochs to run training.
minImprovement - Minimum relative improvement per epoch.
reporter - Reporter to which intermediate results are reported, or null for no reporting.
Returns:
The trained classifier.
Throws:
IOException

emTrain

public static TradNaiveBayesClassifier emTrain(TradNaiveBayesClassifier initialClassifier,
                                               Factory<TradNaiveBayesClassifier> classifierFactory,
                                               Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
                                               Corpus<ObjectHandler<CharSequence>> unlabeledData,
                                               double minTokenCount,
                                               int maxEpochs,
                                               double minImprovement,
                                               Reporter reporter)
                                        throws IOException
Apply the expectation maximization (EM) algorithm to train a traditional naive Bayes classifier using the specified labeled and unlabeled data, initial classifier, factory for creating subsequent classifiers, maximum number of epochs, minimum improvement per epoch, and reporter to which progress reports are sent.

Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
maxEpochs - Maximum number of epochs to run training.
minImprovement - Minimum relative improvement per epoch.
reporter - Reporter to which intermediate results are reported, or null for no reporting.
Returns:
The trained classifier.
Throws:
IOException