What is Expectation Maximization?

Expectation maximization (EM) is a very general technique for finding posterior modes of mixture models using a combination of supervised and unsupervised data.

This Tutorial

In this tutorial, we will explain the basic form of the EM algorithm, and go into depth on an application to classification using a multinomial (aka naive Bayes) classification model. In particular, we will take a handful of labeled training examples and use them to bootstrap a classifier, using unlabeled training data to help with estimation.

In this tutorial, we replicate the basic EM results of Nigam, McCallum and Mitchell's (2006) paper describing EM for naive Bayes [see references].

The EM Algorithm

The EM algorithm iteratively computes expectations (the E step) given a current model, then uses them as training data to estimate a new model (the M step). It works because the combination of an E and an M step is guaranteed not to increase the error, and the error is bounded below. Therefore, the sequence of error values must converge.

EM (Soft) Clustering

EM also works when there is no supervised training data at all, in which case it performs (soft) clustering. The initial classifier then needs to be built randomly (perhaps using a good seeding method like k-means++).

Once the first model is initialized, EM deterministically hill climbs until it finds a set of parameters yielding a local minimum of the error. Standard operating procedure is to run EM from a number of different random starting points and choose the result with the lowest error.

Traditional Naive Bayes

For EM, we use a traditional form of the so-called "naive Bayes" classifier. Although there is nothing special about naive Bayes per se, it is a particularly simple model with which to illustrate EM. It also performs well in practice.

LingPipe's Implementation

LingPipe implements the traditional naive Bayes algorithm in the class classify.TradNaiveBayesClassifier. In the rest of this section, we'll explain how this class works.

Bags of Words

Naive Bayes text classifiers typically involve a so-called "bag of words" representation. Specifically, this is a count of tokens occurring in some span of text; some may have count zero, some may have counts in the hundreds. In LingPipe, these are pulled out using tokenizers, which are created from text sequences using tokenizer factories.
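As a minimal sketch of the idea (not LingPipe's tokenizer API), a bag of words is just a map from each token to its count, with word order thrown away. A whitespace split stands in for a tokenizer factory here, and the class name BagOfWords is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: builds a bag of words (token -> count).
// A whitespace split stands in for a real tokenizer factory.
final class BagOfWords {
    static Map<String, Integer> count(String text) {
        Map<String, Integer> bag = new HashMap<>();
        for (String token : text.trim().split("\\s+")) {
            if (token.isEmpty()) continue;       // blank input splits to [""]
            bag.merge(token, 1, Integer::sum);   // increment this token's count
        }
        return bag;
    }

    public static void main(String[] args) {
        Map<String, Integer> bag = count("the pitcher threw the ball");
        System.out.println(bag.get("the"));  // 2
        System.out.println(bag.size());      // 4 distinct tokens
    }
}
```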

Category Distribution

Categories are represented as strings, and each naive Bayes classifier fixes its set of categories at construction time. Naive Bayes estimates a multinomial distribution over categories, p(c). This is sometimes called the prior distribution of categories, though it shouldn't be confused with Bayesian priors on model parameters.

Token Distribution in Categories

For each category c, naive Bayes estimates a multinomial distribution over words, which we write as p(w|c), indicating the dependence of the probability of word w on category c.

Maximum Likelihood Estimation

The maximum likelihood estimate of a naive Bayes model is computed from simple relative frequencies. The category distribution is estimated by:

p'(c) = freq(c) / Σc' freq(c')

where freq(c) is the number of times category c shows up in the training data, and the denominator is the total number of training instances (each instance has exactly one category).

The maximum likelihood estimates for words in a category are computed similarly:

p'(w|c) = freq(w,c) / Σw' freq(w',c)

where freq(w,c) is the number of times the word w appeared in a document labeled with category c.

Smoothing with Dirichlet (Additive) Priors

Maximum likelihood estimates assign zero probability to words that were not seen in a category during training. To overcome this gross underestimation, it is common to smooth the estimated distribution. A typical way of doing this corresponds to assigning a Dirichlet prior to the multinomial parameters (this is a typical Bayesian prior), then computing the maximum a posteriori (MAP) estimate instead of the maximum likelihood estimate.

In practice, this works out to a technique introduced by Laplace, known as additive smoothing. The Dirichlet prior is parameterized by a prior number of counts per outcome. For instance, we can take 0.5 as the prior number of counts for a category, then estimate:

p"(c) = (freq(c) + 0.5) / Σc' (freq(c') + 0.5)

and similarly for the word in category estimates, which we might give a 0.01 prior count:

p"(w|c) = (freq(w,c) + 0.01) / Σw' (freq(w',c) + 0.01)

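The additive estimates above can be sketched in a few lines of Java, assuming counts are held in plain maps; the class AdditiveSmoothing and its method are hypothetical, not part of LingPipe. Setting the prior count to 0.0 recovers the maximum likelihood estimates from the previous section.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper for additive (Dirichlet-prior) smoothing of a multinomial:
//   p(o) = (freq(o) + prior) / sum_o' (freq(o') + prior)
// With priorCount = 0.0 this is the maximum likelihood estimate.
final class AdditiveSmoothing {
    // outcomes: every possible outcome (all categories, or the whole vocabulary
    // for p(w|c)); counts: observed frequencies, where a missing entry means 0.
    static Map<String, Double> estimate(List<String> outcomes,
                                        Map<String, Double> counts,
                                        double priorCount) {
        double total = 0.0;
        for (String o : outcomes)
            total += counts.getOrDefault(o, 0.0) + priorCount;
        Map<String, Double> probs = new LinkedHashMap<>();
        for (String o : outcomes)
            probs.put(o, (counts.getOrDefault(o, 0.0) + priorCount) / total);
        return probs;
    }

    public static void main(String[] args) {
        List<String> vocab = List.of("hitting", "pitching", "inning");
        Map<String, Double> freq = Map.of("hitting", 3.0, "pitching", 1.0);
        System.out.println(estimate(vocab, freq, 0.0));   // MLE: inning gets 0.0
        System.out.println(estimate(vocab, freq, 0.01));  // smoothed: inning > 0
    }
}
```

The list of outcomes matters: the denominator sums over every outcome, including those with zero observed count, which is how unseen words end up with a small non-zero probability.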
Inference

Given estimates of p(c) and p(w|c), we can classify a new text consisting of a sequence of words ws using Bayes's rule (which is where the "Bayes" in the name comes from):

p(c|ws) = p(ws|c) * p(c) / p(ws)

We expand out the probability of the words by assuming each word is generated independently given the category (this is the "naive" assumption), so the probability of all the words is the product of the probabilities of each word:

p(ws|c) = Πi p(ws[i]|c)

In statistical terms, the naive assumption is that the distribution of words in each category is multinomial.

We can compute the marginal p(ws) by summation over its probability in each category weighted by the probability of that category:

p(ws) = Σc' p(ws|c') * p(c')
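Putting these three equations together gives the minimal Java sketch of classification below, assuming the estimates are given as nested maps and every input word has a (smoothed, non-zero) probability in every category. It works in log space and normalizes with the log-sum-exp trick, because multiplying many small p(w|c) values underflows for long documents. The names are illustrative, not LingPipe's.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch of naive Bayes inference:
//   p(c|ws) = p(c) * prod_i p(ws[i]|c) / sum_c' p(c') * prod_i p(ws[i]|c')
// computed with natural logs and normalized via log-sum-exp to avoid underflow.
// Assumes every word in ws has a non-zero probability in every category.
final class NaiveBayesInference {
    static Map<String, Double> posterior(Map<String, Double> catProb,
                                         Map<String, Map<String, Double>> wordProb,
                                         List<String> ws) {
        Map<String, Double> logJoint = new LinkedHashMap<>();
        double max = Double.NEGATIVE_INFINITY;
        for (String c : catProb.keySet()) {
            double lp = Math.log(catProb.get(c));            // log p(c)
            for (String w : ws)
                lp += Math.log(wordProb.get(c).get(w));      // + log p(w|c)
            logJoint.put(c, lp);
            max = Math.max(max, lp);
        }
        double sum = 0.0;                                    // p(ws) scaled by exp(-max)
        for (double lp : logJoint.values())
            sum += Math.exp(lp - max);
        Map<String, Double> post = new LinkedHashMap<>();
        for (Map.Entry<String, Double> e : logJoint.entrySet())
            post.put(e.getKey(), Math.exp(e.getValue() - max) / sum);
        return post;
    }

    public static void main(String[] args) {
        Map<String, Double> pc = Map.of("c1", 0.5, "c2", 0.5);
        Map<String, Map<String, Double>> pw = Map.of(
            "c1", Map.of("hee", 0.8, "haw", 0.2),
            "c2", Map.of("hee", 0.5, "haw", 0.5));
        // p(c1|hee) is about 0.615 and p(c2|hee) about 0.385
        System.out.println(posterior(pc, pw, List.of("hee")));
    }
}
```

The model in main is the hee/haw model from the example that follows, so its output can be checked against the hand calculation there.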

Attenuation of Naive Bayes

The reason naive Bayes is said to be naive is that it considers each word in the document to be independent. In reality, we know word distributions per document are much more dispersed than this simple multinomial model would indicate. In particular, if a word occurs once, it's much more likely to occur again, as are other topically related words. For instance, if I mention baseball and pitching, the word "hitting" is much more likely to show up than it would be in a random document.

The result of this (clearly false) independence assumption is a kind of attenuation of answers. If a word shows up ten times, it contributes $p(w|c)^{10}$ to the total probability of the sequence of words. As a result, the conditional inferences p(c|ws) tend toward 0 or 1 as the length of the input increases.

For instance, consider a model with two words hee and haw, and two categories, c1 and c2. If we assume the model:

p(c1) = p(c2) = 0.5

p(hee|c1) = 0.8      p(hee|c2) = 0.5
p(haw|c1) = 0.2      p(haw|c2) = 0.5

we can work out p(c|hee) as:

p(hee,c1) = p(hee|c1) * p(c1) = 0.8 * 0.5 = 0.4
p(hee,c2) = p(hee|c2) * p(c2) = 0.5 * 0.5 = 0.25

p(c1|hee) = p(hee,c1) / (p(hee,c1) + p(hee,c2))

          = 0.4/(0.4 + 0.25) = .615

p(c2|hee) = p(hee,c2) / (p(hee,c1) + p(hee,c2))

          = 0.25/(0.4 + 0.25) = .385

But now consider what happens with longer sequences of hees.

p(hee,hee|c1) = 0.8 * 0.8
p(hee,hee|c2) = 0.5 * 0.5

$$p(\text{hee}^n \mid c_1) = 0.8^n \qquad p(\text{hee}^n \mid c_2) = 0.5^n$$

$$p(c_1 \mid \text{hee}^n) = \frac{0.8^n \cdot 0.5}{0.8^n \cdot 0.5 + 0.5^n \cdot 0.5} = \frac{0.8^n}{0.8^n + 0.5^n}$$

So as the number of hees increases, the probability estimate for category c1 approaches 1.0. For instance, with n = 10, we have $p(c_1 \mid \text{hee}^{10}) = 0.991$.
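A few lines of Java reproduce these numbers and show how quickly the posterior saturates as n grows. This is an illustrative sketch; the class name Attenuation is hypothetical.

```java
// Reproduces the hee/haw example: p(c1|hee^n) = 0.8^n / (0.8^n + 0.5^n),
// since the equal category probabilities of 0.5 cancel.
// Math.pow is fine for small n; long inputs should be handled in log space.
final class Attenuation {
    static double posteriorC1(int n) {
        double a = Math.pow(0.8, n);   // p(hee^n | c1)
        double b = Math.pow(0.5, n);   // p(hee^n | c2)
        return a / (a + b);
    }

    public static void main(String[] args) {
        for (int n : new int[] {1, 2, 5, 10, 20})
            System.out.printf("n=%2d  p(c1|hee^n)=%.3f%n", n, posteriorC1(n));
        // n = 1 gives 0.615 and n = 10 gives 0.991, matching the text
    }
}
```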

Training (or Testing) Error

The standard error measurement for statistical classifiers is log loss, the negative log probability of the data. For a sequence wss of token sequences and a parallel sequence cs of categories, the log probability (the negative of the error) is:

log [p(wss,cs|θ) p(θ)] = log p(θ) + Σi log [p(wss[i]|cs[i]) p(cs[i])]

where θ is the set of parameters, p(ws|c) and p(c) are as defined above, and p(θ) is the density of the parameters under their respective Dirichlet priors. Explaining p(θ) is beyond the scope of this tutorial, but see the LingPipe method stats.Statistics.dirichletLog2Prob() for full details.
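The data term of this objective (everything except log p(θ)) is straightforward to compute. The sketch below assumes a fully specified model held in plain maps, natural logarithms, and a hypothetical Item record pairing a token list with its category; it is not LingPipe's implementation.

```java
import java.util.List;
import java.util.Map;

// Sketch of the data term of the objective above:
//   sum_i log [ p(wss[i] | cs[i]) * p(cs[i]) ]
// The log prior density log p(theta) is omitted. Natural logs are used; the
// base only rescales the value. The error (log loss) is the negative of this.
final class DataLogProb {
    // Hypothetical pairing of a token sequence with its category.
    record Item(List<String> words, String cat) {}

    static double dataLogProb(Map<String, Double> catProb,
                              Map<String, Map<String, Double>> wordProb,
                              List<Item> items) {
        double total = 0.0;
        for (Item item : items) {
            total += Math.log(catProb.get(item.cat()));               // log p(c)
            for (String w : item.words())
                total += Math.log(wordProb.get(item.cat()).get(w));  // log p(w|c)
        }
        return total;
    }

    public static void main(String[] args) {
        Map<String, Double> pc = Map.of("c1", 0.5, "c2", 0.5);
        Map<String, Map<String, Double>> pw = Map.of(
            "c1", Map.of("hee", 0.8, "haw", 0.2),
            "c2", Map.of("hee", 0.5, "haw", 0.5));
        List<Item> items = List.of(new Item(List.of("hee"), "c1"),
                                   new Item(List.of("haw", "haw"), "c2"));
        double lp = dataLogProb(pc, pw, items);
        System.out.println("log prob = " + lp + ", log loss = " + (-lp));
    }
}
```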

Length Normalization

To compensate for the effect of length, one common strategy is to length normalize the inputs. That is, treat the input as if it were effectively N tokens long. Mathematically, this is done through exponentiation, where length(ws) is the number of tokens in ws:

$$p_N(ws \mid c) = p(ws \mid c)^{N / \mathrm{length}(ws)}$$

$$\log p_N(ws \mid c) = \frac{N}{\mathrm{length}(ws)} \log p(ws \mid c)$$

In effect, if length(ws) > N, this pulls the estimate p(c|ws) closer to the category estimate p(c). If length(ws) < N, it has the opposite effect, and actually increases the attenuation.

In particular, this will cause $p(c \mid \text{hee}^m) = p(c \mid \text{hee}^n)$ for any m > 0 and n > 0, because $(0.8^n)^{N/n} = 0.8^N$ whatever the value of n (and likewise for 0.5).

In practice, we typically set the length norm to a fairly low number, like 5 or 10 or 20.
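To see the invariance for hee sequences concretely, the sketch below applies length normalization to the hee/haw model: the summed log probabilities are multiplied by N / length(ws) before Bayes's rule is applied. The class and method names are hypothetical, and equal category probabilities are assumed so that they cancel.

```java
import java.util.Arrays;

// Sketch of length normalization for the hee/haw model: the summed log
// probability of a token sequence is scaled by N / length(ws) before
// Bayes's rule is applied. Equal category probabilities are assumed, so
// they cancel. Names are illustrative, not LingPipe's.
final class LengthNorm {
    // log p_N(ws|c) = (N / length(ws)) * sum_i log p(ws[i]|c)
    static double normalizedLogProb(double[] tokenProbs, double lengthNorm) {
        double lp = 0.0;
        for (double p : tokenProbs) lp += Math.log(p);
        return lp * lengthNorm / tokenProbs.length;
    }

    // p(c1 | hee^n) after length normalization to lengthNorm tokens
    static double posteriorC1(int n, double lengthNorm) {
        double[] c1 = new double[n];
        double[] c2 = new double[n];
        Arrays.fill(c1, 0.8);   // p(hee|c1) for each token
        Arrays.fill(c2, 0.5);   // p(hee|c2) for each token
        double la = normalizedLogProb(c1, lengthNorm);
        double lb = normalizedLogProb(c2, lengthNorm);
        double max = Math.max(la, lb);
        double a = Math.exp(la - max), b = Math.exp(lb - max);
        return a / (a + b);
    }

    public static void main(String[] args) {
        for (int n : new int[] {1, 5, 50})
            System.out.printf("n=%2d  normalized p(c1|hee^n)=%.4f%n", n, posteriorC1(n, 9.0));
        // every n gives the same value, 0.8^9 / (0.8^9 + 0.5^9)
    }
}
```

With a length norm of 1.0, a single hee is left at 0.615; with 9.0 it is pushed much closer to 1, which is the increased attenuation for inputs shorter than N described above.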

Constructing and Training Traditional Naive Bayes Classifiers

The full constructor for a trainable instance of naive Bayes is:

TradNaiveBayesClassifier(Set<String> categorySet, 
                         TokenizerFactory tokenizerFactory, 
                         double categoryPrior, 
                         double tokenInCategoryPrior, 
                         double lengthNorm);

where the arguments correspond to the parameters described above.

Training a traditional naive Bayes classifier is carried out in the usual way. The class implements the interface ClassificationHandler<CharSequence,Classification>, which means that calling the naive Bayes method handle(CharSequence cSeq, Classification c) trains naive Bayes on the character sequence and classification.

Traditional naive Bayes also supports a weighted training method, in which all of the frequencies from a training instance are multiplied by a given count. On top of that, it provides a convenience method that trains from a conditional classification, using the conditional probabilities in that classification as counts. It is this latter method that is used by EM.
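The weighted and conditional training methods can be pictured as operating on real-valued counts. The sketch below is a hypothetical stand-in whose class and method names are not LingPipe's: every count from a document is scaled by a weight, and conditional training calls weighted training once per category with weight p(c|x).

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch (not LingPipe's API) of weighted naive Bayes training.
// Counts are doubles so that one document can contribute a fractional weight.
final class WeightedCounts {
    final Map<String, Double> catCount = new HashMap<>();               // freq(c)
    final Map<String, Map<String, Double>> wordCount = new HashMap<>(); // freq(w,c)

    // Weighted training: every count from this document is multiplied by weight.
    void train(List<String> tokens, String cat, double weight) {
        catCount.merge(cat, weight, Double::sum);
        Map<String, Double> wc = wordCount.computeIfAbsent(cat, k -> new HashMap<>());
        for (String t : tokens)
            wc.merge(t, weight, Double::sum);
    }

    // Conditional training: the document is added to each category c
    // with weight p(c|x) taken from a conditional classification.
    void trainConditional(List<String> tokens, Map<String, Double> condProbs) {
        for (Map.Entry<String, Double> e : condProbs.entrySet())
            train(tokens, e.getKey(), e.getValue());
    }

    public static void main(String[] args) {
        WeightedCounts counts = new WeightedCounts();
        counts.trainConditional(List.of("hee", "hee"), Map.of("c1", 0.75, "c2", 0.25));
        System.out.println(counts.catCount);   // c1=0.75, c2=0.25 (order may vary)
        System.out.println(counts.wordCount);  // hee: 1.5 in c1, 0.5 in c2
    }
}
```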

The EM Algorithm

Here's the pseudo-code of the EM algorithm, copied from the javadoc of classify.TradNaiveBayesClassifier. The algorithm is initialized with an initial classifier and loops until convergence:

 set lastClassifier to initialClassifier
 for (epoch = 0; epoch < maxEpochs; ++epoch) {
      create classifier using factory
      train classifier on supervised items
      for (x in unsupervised items) {
          compute p(c|x) with lastClassifier
          for (c in category) 
              train classifier on c weighted by p(c|x)
      }
      evaluate corpus and model probability under classifier
      set lastClassifier to classifier
      break if converged
 }
 return lastClassifier

In each epoch, we train a new classifier on the supervised items, then on the unsupervised items, where each item is added once per category, weighted by the previous classifier's estimate of p(c|x). It's easy to see here that there's nothing naive-Bayes dependent about EM; it can be applied to any classifier that can compute p(c|x) and can be estimated from weighted labeled training data.

In practice, we require a corpus of labeled data, a corpus of unlabeled data, and a factory to generate a new trainable classifier in each epoch. We can also provide a convergence parameter so that if the relative error isn't reduced by at least a pre-specified fraction in an epoch, the algorithm terminates. We can also bound the maximum number of epochs.
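To make the loop concrete, here is a toy, self-contained Java sketch of the pseudo-code above for naive Bayes on token lists. It is not LingPipe's implementation, and its assumptions are flagged in the comments: fixed additive prior counts, an initial classifier trained on the labeled data only, and a convergence test on the relative change in data log likelihood (labeled joint plus unlabeled marginal), without the log p(θ) term. The class ToyEm and its methods are hypothetical.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy semi-supervised EM for naive Bayes, following the pseudo-code above.
// Assumptions (this is not LingPipe's implementation):
//  - documents are token lists and vocab lists every possible token;
//  - estimates use fixed additive prior counts CAT_PRIOR and WORD_PRIOR;
//  - the initial classifier is trained on the labeled data only;
//  - convergence is judged on the relative change in data log likelihood
//    (labeled joint plus unlabeled marginal), leaving out log p(theta).
final class ToyEm {
    static final double CAT_PRIOR = 0.5, WORD_PRIOR = 0.01;
    final List<String> cats;
    final Set<String> vocab;
    final Map<String, Double> catCount = new HashMap<>();   // weighted freq(c)
    final Map<String, Double> catTokens = new HashMap<>();  // weighted tokens in c
    final Map<String, Map<String, Double>> wordCount = new HashMap<>(); // freq(w,c)

    ToyEm(List<String> cats, Set<String> vocab) {
        this.cats = cats;
        this.vocab = vocab;
    }

    // Weighted training (M-step bookkeeping): counts scaled by weight.
    void train(List<String> doc, String cat, double weight) {
        catCount.merge(cat, weight, Double::sum);
        catTokens.merge(cat, weight * doc.size(), Double::sum);
        Map<String, Double> wc = wordCount.computeIfAbsent(cat, k -> new HashMap<>());
        for (String w : doc) wc.merge(w, weight, Double::sum);
    }

    // log p(doc, cat) under the additively smoothed estimates
    double logJoint(List<String> doc, String cat) {
        double totalDocs = 0.0;
        for (String c : cats) totalDocs += catCount.getOrDefault(c, 0.0);
        double lp = Math.log((catCount.getOrDefault(cat, 0.0) + CAT_PRIOR)
                             / (totalDocs + CAT_PRIOR * cats.size()));
        double denom = catTokens.getOrDefault(cat, 0.0) + WORD_PRIOR * vocab.size();
        Map<String, Double> wc = wordCount.getOrDefault(cat, Map.of());
        for (String w : doc) lp += Math.log((wc.getOrDefault(w, 0.0) + WORD_PRIOR) / denom);
        return lp;
    }

    // log p(doc) = log sum_c p(doc, c), via log-sum-exp
    double logMarginal(List<String> doc) {
        double max = Double.NEGATIVE_INFINITY;
        for (String c : cats) max = Math.max(max, logJoint(doc, c));
        double sum = 0.0;
        for (String c : cats) sum += Math.exp(logJoint(doc, c) - max);
        return max + Math.log(sum);
    }

    // E step for one document: p(c|doc) for every category
    Map<String, Double> posterior(List<String> doc) {
        double logZ = logMarginal(doc);
        Map<String, Double> post = new LinkedHashMap<>();
        for (String c : cats) post.put(c, Math.exp(logJoint(doc, c) - logZ));
        return post;
    }

    static ToyEm em(List<String> cats, Set<String> vocab,
                    List<List<String>> labeledDocs, List<String> labels,
                    List<List<String>> unlabeledDocs,
                    int maxEpochs, double minImprovement) {
        ToyEm last = new ToyEm(cats, vocab);               // initial classifier
        for (int i = 0; i < labeledDocs.size(); ++i)
            last.train(labeledDocs.get(i), labels.get(i), 1.0);
        double lastLogProb = Double.NaN;
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            ToyEm next = new ToyEm(cats, vocab);           // fresh classifier
            for (int i = 0; i < labeledDocs.size(); ++i)  // supervised items
                next.train(labeledDocs.get(i), labels.get(i), 1.0);
            for (List<String> doc : unlabeledDocs)        // weighted by p(c|x) from last
                for (Map.Entry<String, Double> e : last.posterior(doc).entrySet())
                    next.train(doc, e.getKey(), e.getValue());
            double logProb = 0.0;                          // evaluate under next
            for (int i = 0; i < labeledDocs.size(); ++i)
                logProb += next.logJoint(labeledDocs.get(i), labels.get(i));
            for (List<String> doc : unlabeledDocs)
                logProb += next.logMarginal(doc);
            last = next;
            if (!Double.isNaN(lastLogProb)
                    && Math.abs(logProb - lastLogProb) / Math.abs(lastLogProb) < minImprovement)
                break;                                     // converged
            lastLogProb = logProb;
        }
        return last;
    }
}
```

As in the pseudo-code, a fresh model is built each epoch from the labeled data plus the weighted unlabeled data, so the only state carried between epochs is the previous model used for the E step.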

EM in LingPipe's Naive Bayes

LingPipe's classify.TradNaiveBayesClassifier class provides two static methods for calculating EM estimates of naive Bayes classifiers. One provides an iterator over the results of each training epoch, but we will focus on the all-in-one method that does everything automatically. Its signature is a doozy:

static TradNaiveBayesClassifier
emTrain(TradNaiveBayesClassifier initialClassifier, 
        Factory<TradNaiveBayesClassifier> classifierFactory, 

        Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
        Corpus<ObjectHandler<CharSequence>> unlabeledData,

        double minTokenCount, 

        int maxEpochs, 
        double minImprovement, 

        Reporter reporter); 

The first argument is the initial classifier. This is pulled out as a separate argument to enable soft clustering to be easily implemented, as well as to allow a crude form of annealing we will discuss shortly.

The second argument is a factory to create a fresh classifier in each epoch. Note that the tokenization scheme is rolled into the classifier itself -- it isn't an argument to the EM method.

The third and fourth arguments are corpora of labeled and unlabeled data respectively. Labeled data consists of classified character sequences, whereas unlabeled data is just character sequences.

The fifth argument is the minimum token count. Any token that does not occur that many times in the combined labeled and unlabeled data will be discarded.

The sixth and seventh arguments control convergence. The maximum number of epochs caps the number of epochs for which EM will run. The minimum improvement is the smallest relative improvement an epoch must achieve for EM to keep going; if an epoch improves by less, EM terminates. These can be effectively turned off by setting max epochs to Integer.MAX_VALUE and the minimum improvement to 0.0.

The last argument is an instance of com.aliasi.io.Reporter, which provides incremental feedback about the progress of the algorithm. Reporters are like loggers, but without any discovery process -- they are fully programmatically configurable. They may be configured with logging levels, and may write to files, standard output, or other streams.

Finally, the returned value is simply an instance of classify.TradNaiveBayesClassifier, which implements JointClassifier<CharSequence>. Traditional naive Bayes also implements LingPipe's util.Compilable interface, meaning that the resulting model may be compiled to a more efficient run-time representation. It may also be serialized as-is in its dynamically trainable form.

Creating Labeled and Unlabeled Corpora

The demo consists of two classes, one to do the EM estimate and set up all the requisite factory parameters, and another to create the corpus. We begin with the corpora.

The corpora are defined in the class TwentyNewsgroupsCorpus, which is in the file src/TwentyNewsgroupsCorpus.java.

The corpus code is not very interesting. The distribution has the posts arranged in two directories, one containing training data and the other test data. Within each directory are 20 subdirectories, one per newsgroup. Within each subdirectory is a set of files, one message per file.

Reading the Raw Data

All of the 20 Newsgroups data is read, using the method read(File), into two member variable maps from categories to arrays of texts, one for the training data and one for the test data:

    final Map<String,String[]> mTrainingCatToTexts;
    final Map<String,String[]> mTestCatToTexts;

For instance, mTestCatToTexts.get("sci.med") returns an array where each entry is a test text for the newsgroup sci.med. The headers are stripped out, as is any message of two or fewer tokens.

Permuting the Data

In order to get different training samples in different trials, we permute the training data using this method (the test data does not need to be permuted):

    public void permuteInstances(Random random) {
        for (String[] xs : mTrainingCatToTexts.values())
            Arrays.permute(xs,random);
    }

Note that we've supplied a random number generator to LingPipe's array-utility permutation function (which implements Knuth's shuffling method), so that we can get reproducible results across runs by using the same seed.
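For reference, Knuth's method (the Fisher-Yates shuffle) walks the array from the end, swapping each element with a uniformly chosen element at or before it. The sketch below is an illustrative generic version, not LingPipe's Arrays.permute source; seeding java.util.Random gives the same permutation on every run.

```java
import java.util.Random;

// Illustrative Fisher-Yates (Knuth) shuffle, in place. Not LingPipe's
// Arrays.permute source; it only shows the technique. Each of the n!
// orderings is equally likely given a uniform random source.
final class Shuffle {
    static <T> void permute(T[] xs, Random random) {
        for (int i = xs.length - 1; i > 0; --i) {
            int j = random.nextInt(i + 1);   // uniform index in [0, i]
            T tmp = xs[i];                   // swap xs[i] and xs[j]
            xs[i] = xs[j];
            xs[j] = tmp;
        }
    }

    public static void main(String[] args) {
        String[] texts = {"a", "b", "c", "d", "e"};
        permute(texts, new Random(45L));     // same seed, same order every run
        System.out.println(String.join(" ", texts));
    }
}
```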

Visiting Training and Testing Data

The two corpus methods that require implementation, visitTrain and visitTest, visit the training and test data by passing the appropriate map and bound to a helper method:

    public void visitTrain(ObjectHandler<Classified<CharSequence>> handler) {
        visit(mTrainingCatToTexts,handler,mMaxSupervisedInstancesPerCategory);
    }

    public void visitTest(ObjectHandler<Classified<CharSequence>> handler) {
        visit(mTestCatToTexts,handler,Integer.MAX_VALUE);
    }

The variable mMaxSupervisedInstancesPerCategory is set for each evaluation to the required number of supervised training instances per category. For testing, we use the max integer value to return all instances. The visit method itself is a straightforward loop over the map:

    private static void visit(Map<String,String[]> catToItems,
                              ObjectHandler<Classified<CharSequence>> handler,
                              int maxItems) {
        for (Map.Entry<String,String[]> entry : catToItems.entrySet()) {
            String cat = entry.getKey();
            Classification c = new Classification(cat);
            String[] texts = entry.getValue();
            for (int i = 0; i < maxItems && i < texts.length; ++i) {
                Classified<CharSequence> classifiedText
                    = new Classified<CharSequence>(texts[i],c);
                handler.handle(classifiedText);
            }
        }
    }
}

Unlabeled Data Corpus

We return an unlabeled data corpus as an anonymous inner class by simply forgetting the categories:

    public Corpus<ObjectHandler<CharSequence>> unlabeledCorpus() {
        return new Corpus<ObjectHandler<CharSequence>>() {
            public void visitTest(ObjectHandler<CharSequence> handler) {
                throw new UnsupportedOperationException();
            }
            public void visitTrain(ObjectHandler<CharSequence> handler) {
                for (String[] texts : mTrainingCatToTexts.values())
                    for (int i = mMaxSupervisedInstancesPerCategory; 
                         i < texts.length; 
                         ++i)
                        handler.handle(texts[i]);
            }
        };
    }

}

Note that we start at the number of supervised instances, so that we don't duplicate data in the labeled and unlabeled corpora.

The EM Estimation Class

The hard part's in defining the corpora. It usually is. The evaluation code for a full random-sample eval, with 10 trials per number of training samples, is given in src/EmTwentyNewsgroups.java.

The Parameters

The code starts with a whole bunch of constant declarations:

    static final long RANDOM_SEED = 45L;

    static final int NUM_REPLICATIONS = 10;
    static final int MAX_EPOCHS = 20;

    static final double MIN_IMPROVEMENT = 0.0001;

    static final double CATEGORY_PRIOR = 0.005; 
    static final double TOKEN_IN_CATEGORY_PRIOR = 0.001;  
    static final double INITIAL_TOKEN_IN_CATEGORY_PRIOR = 0.1;
    static final double DOC_LENGTH_NORM = 9.0;
    static final double COUNT_MULTIPLIER = 1.0;
    static final double MIN_COUNT = 0.0001;

    static final TokenizerFactory TOKENIZER_FACTORY = tokenizerFactory();

Tokenization

The only interesting code behind this is in the tokenizer factory construction method (yes, a tokenizer factory factory method):

static TokenizerFactory tokenizerFactory() {
    TokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
    factory = new RegExFilteredTokenizerFactory(factory,Pattern.compile("\\p{Alpha}+"));
    factory = new LowerCaseTokenizerFactory(factory);
    factory = new EnglishStopTokenizerFactory(factory);
    return factory;
}

As usual, this sets up a chain of filters on tokens. First, we use our basic Indo-European tokenizer factory, which is provided as a singleton instance. Then, we pass its output through a number of filters. First, we remove all tokens that are not purely alphabetic by requiring every passed token to match the regular expression \p{Alpha}+ (with the backslash appropriately escaped, of course). Next, we convert tokens to lowercase, then remove ones in the English stop list. We didn't try any variations of this method, as it more or less matches what others have done.
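For readers without LingPipe at hand, the following plain-Java sketch approximates the same chain in the same order: keep purely alphabetic tokens, lowercase them, then drop stop words. The regex tokenizer is a crude stand-in for IndoEuropeanTokenizerFactory, and STOP_WORDS is a small hypothetical sample, not LingPipe's English stop list.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Plain-Java approximation of the tokenizer chain, for illustration only.
// TOKEN is a crude stand-in for IndoEuropeanTokenizerFactory, and STOP_WORDS
// is a small hypothetical sample, not LingPipe's English stop list.
final class TokenFilters {
    static final Pattern TOKEN = Pattern.compile("\\p{Alnum}+|\\S"); // words, numbers, or one symbol
    static final Pattern ALPHA = Pattern.compile("\\p{Alpha}+");     // same filter as the demo
    static final Set<String> STOP_WORDS = Set.of("a", "an", "and", "in", "is", "of", "the", "to");

    static List<String> tokens(String text) {
        List<String> out = new ArrayList<>();
        Matcher m = TOKEN.matcher(text);
        while (m.find()) {
            String tok = m.group();
            if (!ALPHA.matcher(tok).matches()) continue;  // 1. keep purely alphabetic tokens
            tok = tok.toLowerCase(Locale.ROOT);           // 2. lowercase
            if (STOP_WORDS.contains(tok)) continue;       // 3. drop stop words
            out.add(tok);
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(tokens("The 1993 Bill James book, in short, sucked!"));
        // [bill, james, book, short, sucked]
    }
}
```

Note that Java's \p{Alpha} matches only ASCII letters unless Unicode character classes are enabled, so accented words are split or dropped by this sketch.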

Reporting

We set up a reporter to report to the system output stream all messages at the debug level or above, using the Latin1 (ISO-8859-1) character encoding:

        Reporter reporter = Reporters.stream(System.out,"ISO-8859-1").setLevel(LogLevel.DEBUG);

Constructing the Corpora

We construct the corpora using the constructor and a path to the unpacked data:

        final TwentyNewsgroupsCorpus corpus = new TwentyNewsgroupsCorpus(corpusPath);
        Corpus<ObjectHandler<CharSequence>> unlabeledCorpus = corpus.unlabeledCorpus();

Trial Outer Loops

The outer loops of the code just run over different numbers of supervised items, evaluating each a number of times based on different permutations:

        for (int numSupervisedItems : new Integer[] {  1, 2, 4, 8, 16, 32, 64, 128, 256, 512 }) {
            corpus.setMaxSupervisedInstancesPerCategory(numSupervisedItems);
            double[] accs = new double[NUM_REPLICATIONS];
            double[] accsEm = new double[NUM_REPLICATIONS];
            for (int trial = 0; trial < NUM_REPLICATIONS; ++trial) {
                corpus.permuteInstances(random);
                ...

The arrays accs and accsEm buffer the supervised and EM accuracies across trials so that their means and standard deviations may be reported.

The Real Work

After all of this setup, the inner loop that does the estimation and evaluation is quite simple, though rather verbose:

...
TradNaiveBayesClassifier initialClassifier
    = new TradNaiveBayesClassifier(corpus.categorySet(),
                                   TOKENIZER_FACTORY,
                                   CATEGORY_PRIOR,
                                   INITIAL_TOKEN_IN_CATEGORY_PRIOR,
                                   DOC_LENGTH_NORM);

Factory<TradNaiveBayesClassifier> classifierFactory 
    = new Factory<TradNaiveBayesClassifier>() {
        public TradNaiveBayesClassifier create() {
            return new TradNaiveBayesClassifier(corpus.categorySet(),
                                                TOKENIZER_FACTORY,
                                                CATEGORY_PRIOR,
                                                TOKEN_IN_CATEGORY_PRIOR,
                                                DOC_LENGTH_NORM);
        }
    };

TradNaiveBayesClassifier emClassifier
    = TradNaiveBayesClassifier.emTrain(initialClassifier,
                                       classifierFactory,
                                       corpus,
                                       unlabeledCorpus,
                                       MIN_COUNT,
                                       MAX_EPOCHS,
                                       MIN_IMPROVEMENT,
                                       reporter);
accs[trial] = eval(initialClassifier,corpus);
accsEm[trial] = eval(emClassifier,corpus);
...

First, we set up the initial classifier, which is itself a traditional naive Bayes classifier. It has its own token-in-category prior, which is larger than the one used by the classifiers the factory creates. This makes the predictions less attenuated to start, which is a kind of annealing strategy. For soft clustering, the initial classifier would perform random assignments.

Next, we create a classifier factory as an anonymous inner class. It just returns a new traditional naive Bayes classifier with a different prior than for the initial classifier.

Finally, we call the EM method with all of the parameters.

Then, we run evaluation, first on the initial classifier, then on the EM-trained classifier. The evaluation method just uses a LingPipe classifier evaluator to do its work; it's set up to play nicely with corpora.

    static double eval(TradNaiveBayesClassifier classifier, 
                       Corpus<ObjectHandler<Classified<CharSequence>>> corpus)
        throws IOException, ClassNotFoundException {

        String[] categories = classifier.categorySet().toArray(new String[0]);
        Arrays.sort(categories);
        @SuppressWarnings("unchecked")
        JointClassifier<CharSequence> compiledClassifier
            = (JointClassifier<CharSequence>)
            AbstractExternalizable.compile(classifier);
        boolean storeInputs = false;
        JointClassifierEvaluator<CharSequence> evaluator
            = new JointClassifierEvaluator<CharSequence>(compiledClassifier,
                                                         categories,
                                                         storeInputs);
        corpus.visitTest(evaluator);
        return evaluator.confusionMatrix().totalAccuracy();
    }

Evaluation just generates the category array from the category set, which it sorts for convenience in output (not really used here). Then it compiles the classifier it was passed using the abstract externalizable utility method. Then it creates an evaluator and runs it over the test set in the corpus. Finally, it returns the accuracy from the evaluator's confusion matrix, which we put into the array accumulators.

Once all of the trials for a given number of supervised training samples have finished, we print out the means and standard deviations of the accumulated accuracies.
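The summary statistics are simple to compute. The text does not say whether the reported deviation divides by n or n - 1, so this sketch assumes the population form (divide by n); the class name AccuracySummary and the sample values are hypothetical.

```java
// Hypothetical helper for summarizing per-trial accuracies. The text does not
// say whether its deviation divides by n or n - 1; this sketch assumes n.
final class AccuracySummary {
    static double mean(double[] xs) {
        double sum = 0.0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    static double sd(double[] xs) {
        double mu = mean(xs);
        double sumSq = 0.0;
        for (double x : xs) sumSq += (x - mu) * (x - mu);
        return Math.sqrt(sumSq / xs.length);  // population standard deviation
    }

    public static void main(String[] args) {
        double[] accs = {0.80, 0.81, 0.79};   // hypothetical per-trial accuracies
        System.out.printf("mean=%.3f sd=%.3f%n", mean(accs), sd(accs));
    }
}
```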

Running the Example

First, you need to download and unpack the corpus. Let's say the unpacked data is in a directory called TWENTY_NEWSGROUPS. Then the demo may be invoked through ant using:

ant -Dnewsgroups.path=TWENTY_NEWSGROUPS em

which prints the following (abridged):

CORPUS PATH=e:\data\20news\unpacked
DOC LENGTH NORM=9.0
CATEGORY PRIOR=0.0050
TOKEN IN CATEGORY PRIOR=0.0010
INITIAL TOKEN IN CATEGORY PRIOR=0.1
NUM REPS=10
MAX EPOCHS=20
RANDOM SEED=45

alt.atheism #train=480 #test=319
comp.graphics #train=582 #test=389
comp.os.ms-windows.misc #train=590 #test=393
comp.sys.ibm.pc.hardware #train=588 #test=391
comp.sys.mac.hardware #train=576 #test=384
comp.windows.x #train=588 #test=391
misc.forsale #train=578 #test=381
rec.autos #train=592 #test=396
rec.motorcycles #train=597 #test=398
rec.sport.baseball #train=596 #test=397
rec.sport.hockey #train=600 #test=399
sci.crypt #train=595 #test=396
sci.electronics #train=591 #test=393
sci.med #train=594 #test=395
sci.space #train=592 #test=394
soc.religion.christian #train=599 #test=398
talk.politics.guns #train=546 #test=364
talk.politics.mideast #train=563 #test=376
talk.politics.misc #train=465 #test=310
talk.religion.misc #train=376 #test=251
TOTALS: #train=11288 #test=7515 #combined=18803


SUPERVISED DOCS/CAT=1
TRIAL=0
      :12 epoch=   0   dataLogProb=    -1237465.70   modelLogProb=    15035211.18   logProb=    13797745.48   diff=            NaN
      :24 epoch=   1   dataLogProb=    -1231570.90   modelLogProb=    15123510.94   logProb=    13891940.04   diff= 0.006803584231
      :37 epoch=   2   dataLogProb=    -1216209.52   modelLogProb=    15529135.24   logProb=    14312925.72   diff= 0.029851989631
      :49 epoch=   3   dataLogProb=    -1199468.53   modelLogProb=    16320003.21   logProb=    15120534.68   diff= 0.054876929245
...
     3:25 epoch=  16   dataLogProb=    -1176298.68   modelLogProb=    17980646.97   logProb=    16804348.29   diff= 0.000162049813
     3:37 epoch=  17   dataLogProb=    -1176195.94   modelLogProb=    17982292.01   logProb=    16806096.07   diff= 0.000104002156
     3:49 epoch=  18   dataLogProb=    -1176095.45   modelLogProb=    17983720.34   logProb=    16807624.89   diff= 0.000090963924
     3:49 Converged
ACC=0.129   EM ACC=0.377

...
TRIAL=9
    31:59 epoch=   0   dataLogProb=    -1237448.69   modelLogProb=    15030709.70   logProb=    13793261.01   diff=            NaN
    32:11 epoch=   1   dataLogProb=    -1230837.52   modelLogProb=    15115282.06   logProb=    13884444.54   diff= 0.006588951110
...
    35:11 epoch=  16   dataLogProb=    -1176545.63   modelLogProb=    17810014.34   logProb=    16633468.71   diff= 0.000041518289
    35:11 Converged
ACC=0.108   EM ACC=0.226

     ---------------------
#Sup=   1  Supervised mean(acc)=0.130 sd(acc)=0.009   EM mean(acc)=0.389 sd(acc)=0.068          35:18

...
...

SUPERVISED DOCS/CAT=512
TRIAL=0
  4:06:27 epoch=   0   dataLogProb=    -1172021.48   modelLogProb=    18066945.68   logProb=    16894924.20   diff=            NaN
  4:06:38 epoch=   1   dataLogProb=    -1171420.14   modelLogProb=    18226716.74   logProb=    17055296.60   diff= 0.009447502378
  4:06:49 epoch=   2   dataLogProb=    -1171342.28   modelLogProb=    18242581.15   logProb=    17071238.88   diff= 0.000934303845
  4:07:00 epoch=   3   dataLogProb=    -1171323.18   modelLogProb=    18246273.05   logProb=    17074949.87   diff= 0.000217359169
  4:07:11 epoch=   4   dataLogProb=    -1171315.29   modelLogProb=    18247489.49   logProb=    17076174.21   diff= 0.000071701103
  4:07:11 Converged
ACC=0.800   EM ACC=0.807

...

  4:18:08 epoch=   5   dataLogProb=    -1171301.28   modelLogProb=    18248274.13   logProb=    17076972.84   diff= 0.000055576147
  4:18:08 Converged
ACC=0.803   EM ACC=0.810

     ---------------------
#Sup= 512  Supervised mean(acc)=0.802 sd(acc)=0.001   EM mean(acc)=0.808 sd(acc)=0.001        4:18:17

The entire run took a little over four hours.

For each epoch, we report the total elapsed time, the epoch ID, the log probability of the data Σi log [p(wss[i]|cs[i]) p(cs[i])] (dataLogProb), the log probability of the model log p(θ) (modelLogProb), and their sum (logProb). It is this sum that's maximized by EM (error is negative log probability, which is minimized). Note that the model log prob is actually the log of a density, which is why it can be positive. The final column (diff) reports the relative difference in the sum between consecutive epochs.
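The diff column appears to be the absolute change in logProb between consecutive epochs divided by the mean of the two magnitudes. That formula reproduces the logged values, but it is inferred from the output above rather than taken from LingPipe's source. The sketch below, with a hypothetical class name, checks it against trial 0 of the one-document-per-category run, where the change from epoch 17 to epoch 18 falls below MIN_IMPROVEMENT = 0.0001 and EM reports convergence.

```java
// Relative difference inferred from the logged diff column (an assumption,
// not LingPipe's documented formula): |x - y| / ((|x| + |y|) / 2).
final class RelativeDiff {
    static double relativeDiff(double x, double y) {
        return Math.abs(x - y) / ((Math.abs(x) + Math.abs(y)) / 2.0);
    }

    // EM stops once an epoch's relative improvement drops below the threshold.
    static boolean converged(double previous, double current, double minImprovement) {
        return relativeDiff(previous, current) < minImprovement;
    }

    public static void main(String[] args) {
        // logProb at epochs 0 and 1 of trial 0 with one supervised doc per category
        System.out.println(relativeDiff(13797745.48, 13891940.04)); // about 0.0068036
        // logProb at epochs 17 and 18: below 0.0001, so EM stops
        System.out.println(converged(16806096.07, 16807624.89, 0.0001)); // true
    }
}
```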

Note that as the amount of supervised data goes up, and the amount of unsupervised data goes down, the number of epochs required to reach convergence also goes down.

We collect these results into more readable form in the next section.

Results on 20 Newsgroups

The 20 Newsgroups corpus is a widely used training and test set for natural language classification. It consists of messages posted to 20 different newsgroups, some of which are on closely related topics and some on very diverse topics.

Jason Rennie maintains the data set at MIT:

We used the recommended by-date version of the corpus, which splits the articles into a standard train and test set, with the posts in the test set made chronologically later than the ones in the train set. The corpus contains 18846 documents, which include headers. Here's an example from rec.sport.baseball:

     From: ez027993@dale.ucdavis.edu (Gary The Burgermeister Huckabay)
     Subject: Bill James Player Rating Book 1993.
     Organization: Harold Brooks Duck L'Orange Club, Ltd.
     Lines: 26    

    (Dave 'This has never happened to me before' Kirsch) writes:
    >  Correction: "Nied was the only player identified in this book as a grade A
    >prospect who was exposed to the draft..", according to Bill James in the
    >'Stop the Presses' section preceding his player evaluations. He valued Nied
    >at $21, and said that Nied's value does not increase significantly as a
    >result of his selection (although he did catch a break getting away from the
    >strongest rotation in baseball). 
    
    I thought Bill James' latest book completely and totally sucked.  I bought
    it, but will not purchase anything of his ever again without THOROUGHLY
    looking at it first.  What tripe.
    
    The book is inconsistent, and filled with selective analysis.  James
    claims to be looking forward, and then makes some absolutely bizarre
    statements of value.  Not only that, but I got the impression he
    probably glanced at the book for about an hour before he put his name
    on it. 
    
    To say I was disappointed is a grand understatement.
    
    
    -- 
    *     Gary Huckabay      * Kevin Kerr: The Al Feldstein of the mid-90's! *
    * "A living argument for * If there's anything we love more than a huge  *
    *  existence of parallel * .sig, it's someone quoting 100 lines to add   *
    *       universes."      * 3 or 4 new ones.  And consecutive posts, too. *

As you can see, there are headers, quotes from previous posts, and signatures. We stripped off the headers and removed documents with fewer than three tokens, leaving us with a total of 18803 documents, divided among the 20 newsgroups as follows:

NEWSGROUP              #TRAIN  #TEST

alt.atheism               480    319
comp.graphics             582    389
comp.os.ms-windows.misc   590    393
comp.sys.ibm.pc.hardware  588    391
comp.sys.mac.hardware     576    384
comp.windows.x            588    391
misc.forsale              578    381
rec.autos                 592    396
rec.motorcycles           597    398
rec.sport.baseball        596    397
rec.sport.hockey          600    399
sci.crypt                 595    396
sci.electronics           591    393
sci.med                   594    395
sci.space                 592    394
soc.religion.christian    599    398
talk.politics.guns        546    364
talk.politics.mideast     563    376
talk.politics.misc        465    310
talk.religion.misc        376    251

TOTALS:                 11288   7515

The results are presented in the following table, which lists the number of supervised documents per category, followed by the accuracy of the fully supervised classifier trained only on the supervised documents, followed by the accuracy achieved by EM on the combination of the supervised and unsupervised documents. Accuracies are given as the mean +/- standard deviation over the 10 trials.

Supervised and Semi-Supervised Results on 20 Newsgroups

#Supervised   Supervised Accuracy   EM Accuracy

          1   0.130 +/- 0.009       0.389 +/- 0.068
          2   0.183 +/- 0.021       0.483 +/- 0.038
          4   0.239 +/- 0.016       0.615 +/- 0.027
          8   0.357 +/- 0.012       0.661 +/- 0.020
         16   0.481 +/- 0.012       0.712 +/- 0.012
         32   0.581 +/- 0.011       0.735 +/- 0.007
         64   0.677 +/- 0.007       0.755 +/- 0.006
        128   0.735 +/- 0.004       0.770 +/- 0.002
        256   0.778 +/- 0.003       0.789 +/- 0.002
        512   0.802 +/- 0.001       0.808 +/- 0.001

Note that the gap between semi-supervised and fully supervised closes as more supervised documents are added. In theory, with enough supervised documents, the estimator should converge and it shouldn't be possible to get better performance by adding unlabeled data.

Also note that the number of unsupervised documents available decreases as more and more of the training set is used on the supervised side. In typical applications, the amount of unlabeled data is effectively unbounded in the sense that we don't have enough computing power to use all that we have. (Though, having said that, EM is "embarrassingly parallel," since the E step can be computed independently for each document, so with large clusters it can be scaled out quite easily.) It'd be relatively easy to download hundreds of thousands of posts to these rather popular newsgroups, for instance using Google Groups, which reports 175,508 articles in the archive for the newsgroup talk.politics.guns (on 16 April 2009).

These results are substantially better for low numbers of supervised documents than those reported by Nigam, McCallum and Mitchell (2006) [see references]. Although not reported in their paper, Kamal Nigam told us that he used a length normalization of around 100. When we increase length norm to that level, our results look more like theirs. We also found that having a more diffuse initial classifier (higher prior count) led to much better performance.

References