import com.aliasi.cluster.LatentDirichletAllocation; import com.aliasi.corpus.ObjectHandler; import com.aliasi.symbol.SymbolTable; import com.aliasi.util.ObjectToCounterMap; import com.aliasi.util.Strings; import java.util.List; public class LdaReportingHandler implements ObjectHandler { private final SymbolTable mSymbolTable; private final long mStartTime; LdaReportingHandler(SymbolTable symbolTable) { mSymbolTable = symbolTable; mStartTime = System.currentTimeMillis(); } public void handle(LatentDirichletAllocation.GibbsSample sample) { System.out.printf("Epoch=%3d elapsed time=%s\n", sample.epoch(), Strings.msToString(System.currentTimeMillis() - mStartTime)); if ((sample.epoch() % 10) == 0) { double corpusLog2Prob = sample.corpusLog2Probability(); System.out.println(" log2 p(corpus|phi,theta)=" + corpusLog2Prob + " token cross-entropy rate=" + (-corpusLog2Prob/sample.numTokens())); } } void fullReport(LatentDirichletAllocation.GibbsSample sample, int maxWordsPerTopic, int maxTopicsPerDoc, boolean reportTokens) { System.out.println("\nFull Report"); int numTopics = sample.numTopics(); int numWords = sample.numWords(); int numDocs = sample.numDocuments(); int numTokens = sample.numTokens(); System.out.println("epoch=" + sample.epoch()); System.out.println("numDocs=" + numDocs); System.out.println("numTokens=" + numTokens); System.out.println("numWords=" + numWords); System.out.println("numTopics=" + numTopics); for (int topic = 0; topic < numTopics; ++topic) { int topicCount = sample.topicCount(topic); ObjectToCounterMap counter = new ObjectToCounterMap(); for (int word = 0; word < numWords; ++word) counter.set(Integer.valueOf(word),sample.topicWordCount(topic,word)); List topWords = counter.keysOrderedByCountList(); System.out.println("\nTOPIC " + topic + " (total count=" + topicCount + ")"); System.out.println("SYMBOL WORD COUNT PROB Z"); System.out.println("--------------------------------------------------"); for (int rank = 0; rank < maxWordsPerTopic && rank < topWords.size(); ++rank) { int wordId = topWords.get(rank); String word = mSymbolTable.idToSymbol(wordId); int wordCount = sample.wordCount(wordId); int topicWordCount = sample.topicWordCount(topic,wordId); double topicWordProb = sample.topicWordProb(topic,wordId); double z = binomialZ(topicWordCount, topicCount, wordCount, numTokens); System.out.printf("%6d %15s %7d %4.3f %8.1f\n", wordId, word, topicWordCount, topicWordProb, z); } } for (int doc = 0; doc < numDocs; ++doc) { int docCount = 0; for (int topic = 0; topic < numTopics; ++topic) docCount += sample.documentTopicCount(doc,topic); ObjectToCounterMap counter = new ObjectToCounterMap(); for (int topic = 0; topic < numTopics; ++topic) counter.set(Integer.valueOf(topic),sample.documentTopicCount(doc,topic)); List topTopics = counter.keysOrderedByCountList(); System.out.println("\nDOC " + doc); System.out.println("TOPIC COUNT PROB"); System.out.println("----------------------"); for (int rank = 0; rank < topTopics.size() && rank < maxTopicsPerDoc; ++rank) { int topic = topTopics.get(rank); int docTopicCount = sample.documentTopicCount(doc,topic); double docTopicPrior = sample.documentTopicPrior(); double docTopicProb = (sample.documentTopicCount(doc,topic) + docTopicPrior) / (docCount + numTopics * docTopicPrior); System.out.printf("%5d %7d %4.3f\n", topic, docTopicCount, docTopicProb); } System.out.println(); if (!reportTokens) continue; int numDocTokens = sample.documentLength(doc); for (int tok = 0; tok < numDocTokens; ++tok) { int symbol = sample.word(doc,tok); short topic = sample.topicSample(doc,tok); String word = mSymbolTable.idToSymbol(symbol); System.out.print(word + "(" + topic + ") "); } System.out.println(); } } static double binomialZ(double wordCountInDoc, double wordsInDoc, double wordCountinCorpus, double wordsInCorpus) { double pCorpus = wordCountinCorpus / wordsInCorpus; double var = wordsInCorpus * pCorpus * (1 - pCorpus); double dev = Math.sqrt(var); double expected = wordsInDoc * pCorpus; double z = (wordCountInDoc - expected) / dev; return z; } }