import com.aliasi.classify.ClassifierEvaluator; import com.aliasi.classify.Classification; import com.aliasi.classify.ConditionalClassification; import com.aliasi.classify.ConfusionMatrix; import com.aliasi.classify.LogisticRegressionClassifier; import com.aliasi.classify.XValidatingClassificationCorpus; import com.aliasi.stats.AnnealingSchedule; import com.aliasi.stats.RegressionPrior; import com.aliasi.tokenizer.RegExTokenizerFactory; import com.aliasi.tokenizer.TokenizerFactory; import com.aliasi.tokenizer.TokenFeatureExtractor; import com.aliasi.util.AbstractExternalizable; import com.aliasi.util.FeatureExtractor; import com.aliasi.util.Files; import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import java.util.Random; public class TextClassificationDemo { static final File TRAINING_DIR = new File("../../data/fourNewsGroups/4news-train"); static final File TESTING_DIR = new File("../../data/fourNewsGroups/4news-test"); static final String[] CATEGORIES = { "soc.religion.christian", "talk.religion.misc", "alt.atheism", "misc.forsale" }; public static void main(String[] args) throws Exception { if (!TRAINING_DIR.isDirectory()) { System.out.println("Could not find data directory=" + TRAINING_DIR); System.out.println("Have you unpacked 4 newsgroups from $LINGPIPE/demos/data?"); return; } PrintWriter progressWriter = new PrintWriter(System.out,true); progressWriter.println("Reading data."); int numFolds = 4; XValidatingClassificationCorpus corpus = new XValidatingClassificationCorpus(numFolds); for (String category : CATEGORIES) { Classification c = new Classification(category); for (File trainingFile : new File(TRAINING_DIR,category).listFiles()) { String text = Files.readFromFile(trainingFile); corpus.handle(text,c); } for (File trainingFile : new File(TESTING_DIR,category).listFiles()) { String text = Files.readFromFile(trainingFile); corpus.handle(text,c); } } progressWriter.println("Num instances=" + corpus.numInstances() + "."); progressWriter.println("Permuting corpus."); corpus.permuteCorpus(new Random(7117)); // destroys runs of categories progressWriter.println("\nEVALUATING FOLDS\n"); TokenizerFactory tokenizerFactory = new RegExTokenizerFactory("\\p{L}+|\\d+"); // letter+ | digit+ FeatureExtractor featureExtractor = new TokenFeatureExtractor(tokenizerFactory); int minFeatureCount = 5; boolean addInterceptFeature = true; boolean noninformativeIntercept = true; double priorVariance = 0.5; RegressionPrior prior = RegressionPrior.laplace(priorVariance,noninformativeIntercept); AnnealingSchedule annealingSchedule = AnnealingSchedule.exponential(0.002,0.9975); double minImprovement = 0.0000001; int minEpochs = 100; int maxEpochs = 1000; for (int fold = 0; fold < numFolds; ++fold) { corpus.setFold(fold); LogisticRegressionClassifier classifier = LogisticRegressionClassifier.train(featureExtractor, corpus, minFeatureCount, addInterceptFeature, prior, annealingSchedule, minImprovement, minEpochs, maxEpochs, progressWriter); progressWriter.println("\nCLASSIFIER & FEATURES\n"); progressWriter.println(classifier); progressWriter.println("\nEVALUATION\n"); ClassifierEvaluator evaluator = new ClassifierEvaluator(classifier,CATEGORIES); corpus.visitTest(evaluator); progressWriter.printf("FOLD=%5d ACC=%4.2f +/-%4.2f\n", fold, evaluator.confusionMatrix().totalAccuracy(), evaluator.confusionMatrix().confidence95()); } } // 108.6 .001/.999 // -107.9336 .002/.9975 }