import com.aliasi.classify.Classification; import com.aliasi.classify.Classifier; import com.aliasi.classify.ConditionalClassification; import com.aliasi.classify.DynamicLMClassifier; import com.aliasi.classify.JointClassification; import com.aliasi.classify.KnnClassifier; import com.aliasi.classify.NaiveBayesClassifier; import com.aliasi.classify.ScoredClassification; import com.aliasi.classify.TfIdfClassifierTrainer; import com.aliasi.corpus.ClassificationHandler; import com.aliasi.io.FileLineReader; import com.aliasi.lm.LanguageModel; import com.aliasi.lm.NGramBoundaryLM; import com.aliasi.lm.NGramProcessLM; import com.aliasi.lm.TokenizedLM; import com.aliasi.matrix.CosineDistance; import com.aliasi.matrix.TaxicabDistance; import com.aliasi.util.AbstractExternalizable; import com.aliasi.util.Compilable; import com.aliasi.util.FeatureExtractor; import com.aliasi.util.Files; import com.aliasi.util.Proximity; import com.aliasi.tokenizer.EnglishStopTokenizerFactory; import com.aliasi.tokenizer.LowerCaseTokenizerFactory; import com.aliasi.tokenizer.NGramTokenizerFactory; import com.aliasi.tokenizer.PorterStemmerTokenizerFactory; import com.aliasi.tokenizer.RegExTokenizerFactory; import com.aliasi.tokenizer.TokenFeatureExtractor; import com.aliasi.tokenizer.Tokenizer; import com.aliasi.tokenizer.TokenizerFactory; import java.io.BufferedWriter; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; public class Senseval3 { // label for unknown sense static final String UNKNOWN_SENSE = "U"; static SenseEvalDict mDict; static int sClassifierNumber = -1; static final TokenizerFactory NGRAM_TOKENIZER_FACTORY = new NGramTokenizerFactory(4,6); static final TokenizerFactory SPACE_TOKENIZER_FACTORY = new RegExTokenizerFactory("\\S+"); static final TokenizerFactory NORM_TOKENIZER_FACTORY = normTokenizerFactory(); static TokenizerFactory normTokenizerFactory() { TokenizerFactory factory = SPACE_TOKENIZER_FACTORY; factory = new LowerCaseTokenizerFactory(factory); // factory = EnglishStopTokenizerFactory(factory); // factory = PorterStemmerTokenizerFactory(factory); return factory; } public static void main(String[] args) throws ClassNotFoundException, IOException { File dictFile = new File(args[0]); File trainFile = new File(args[1]); File testFile = new File(args[2]); File responseFile = new File(args[3]); sClassifierNumber = Integer.valueOf(args[4]); System.out.println("Dictionary File=" + dictFile.getCanonicalPath()); System.out.println("Training File=" + trainFile.getCanonicalPath()); System.out.println("Testing File=" + testFile.getCanonicalPath()); System.out.println("System Response File=" + responseFile.getCanonicalPath()); System.out.println("classifier id=" + sClassifierNumber); System.out.println(); System.out.println("Reading Dictionary."); SenseEvalDict dict = new SenseEvalDict(dictFile); System.out.println(" #entries=" + dict.size()); System.out.println(); System.out.println("Reading Training Data."); TrainingData trainingData = new TrainingData(trainFile); System.out.println(" #training words=" + trainingData.size()); System.out.println(); System.out.println("Reading Test Data."); TestData testData = new TestData(testFile); System.out.println(" #test cases=" + testData.mWordsPlusCats.size()); System.out.println(); System.out.println("Training and Compiling Models."); SenseEvalModel model = new SenseEvalModel(dict,trainingData); System.out.println(" finished training."); System.out.println(); System.out.println("Running Model over Test Data."); respond(model,testData,responseFile); System.out.println(" finished test data."); System.out.println(); System.out.println("FINISHED."); } static ClassificationHandler createClassifierTrainer(String[] senseIds) { switch (sClassifierNumber) { case 0: // DEFAULT CHARACTER LM CLASSIFIER return DynamicLMClassifier.createNGramProcess(senseIds,5); case 1: // CONFIGURABLE CHARACTER LM CLASSIFIER LanguageModel.Dynamic[] lms5 = new LanguageModel.Dynamic[senseIds.length]; for (int i = 0; i < lms5.length; ++i) lms5[i] = new NGramProcessLM(6, // n-gram 128, // num chars 1.0); // interpolation ratio return new DynamicLMClassifier(senseIds,lms5); case 2: // DEFAULT NAIVE BAYES CLASSIFIER return new NaiveBayesClassifier(senseIds,NORM_TOKENIZER_FACTORY); case 3: // DEFAULT TOKEN UNIGRAM LM CLASSIFIER return DynamicLMClassifier.createTokenized(senseIds, NORM_TOKENIZER_FACTORY, 1); case 4: // DEFAULT TOKEN BIGRAM LM CLASSIFIER return DynamicLMClassifier.createTokenized(senseIds, NORM_TOKENIZER_FACTORY, 2); case 5: // CONFIGURABLE TOKENIZED LM CLASSIFIER W. CHARACTER BOUNDARY LM SMOOTHING LanguageModel.Dynamic[] lms2 = new LanguageModel.Dynamic[senseIds.length]; for (int i = 0; i < lms2.length; ++i) lms2[i] = new TokenizedLM(NORM_TOKENIZER_FACTORY, 3, // n-gram length new NGramBoundaryLM(4,128,0.5,'\uFFFF'), new NGramBoundaryLM(4,128,0.5,'\uFFFF'), 0.1); // interpolation param return new DynamicLMClassifier(senseIds,lms2); case 6: // TF-IDF CLASSIFIER FeatureExtractor featureExtractor5 = new TokenFeatureExtractor(SPACE_TOKENIZER_FACTORY); return new TfIdfClassifierTrainer(featureExtractor5); case 7: // K-NEAREST NEIGHBORS DEFAULT CLASSIFIER (EUCLIDEAN DISTANCE) FeatureExtractor featureExtractor7 = new TokenFeatureExtractor(SPACE_TOKENIZER_FACTORY); return new KnnClassifier(featureExtractor7, 16); // num neighbors to average case 8: // K-NEAREST NEIGHBORS DEFAULT CLASSIFIER (COSINE DISTANCE) FeatureExtractor featureExtractor8 = new TokenFeatureExtractor(NGRAM_TOKENIZER_FACTORY); return new KnnClassifier(featureExtractor8, 5, // num neighbors to average new CosineDistance(), true); default: String msg = "classifier id must be between 0 and 3." + " found id=" + sClassifierNumber; throw new IllegalArgumentException(msg); } } static void respond(SenseEvalModel model, TestData testData, File file) throws IOException { FileOutputStream fileOut = new FileOutputStream(file); OutputStreamWriter osWriter = new OutputStreamWriter(fileOut,"ISO-8859-1"); BufferedWriter bufWriter = new BufferedWriter(osWriter); for (int i = 0; i < testData.mWordsPlusCats.size(); ++i) { String wordPlusCat = testData.mWordsPlusCats.get(i); Classifier classifier = model.get(wordPlusCat); String instanceId = testData.mInstanceIds.get(i); String textToClassify = testData.mTextsToClassify.get(i); Classification classification = classifier.classify(textToClassify); bufWriter.write(wordPlusCat + " " + wordPlusCat + ".bnc." + instanceId); if (classification instanceof ConditionalClassification) { ConditionalClassification condClassification = (ConditionalClassification) classification; for (int rank = 0; rank < condClassification.size(); ++rank) { int conditionalProb = (int) java.lang.Math.round(1000.0 * condClassification.conditionalProbability(rank)); if (rank > 0 && conditionalProb < 1) break; String category = condClassification.category(rank); bufWriter.write(" " + category + "/" + conditionalProb); } } else { bufWriter.write(" " + classification.bestCategory()); } bufWriter.write("\n"); } bufWriter.close(); } static int seek(String lineStartString, String[] lines, int pos) { if (pos == -1) return -1; for ( ; pos < lines.length; ++pos) if (lines[pos].startsWith(lineStartString)) return pos; return -1; } static class TestData { List mWordsPlusCats = new ArrayList(); List mInstanceIds = new ArrayList(); List mTextsToClassify = new ArrayList(); TestData(File file) throws IOException { String[] lines = FileLineReader.readLineArray(file,"ISO-8859-1"); int pos = 0; while ((pos = seek("= 0) { pos = parse(lines,pos); } } int parse(String[] lines, int pos) { String id = extractAttribute("id",lines[pos]); int endIndex = id.indexOf('.',id.indexOf('.')+1); String wordPlusCat = id.substring(0,endIndex); int startIndex = id.lastIndexOf('.') + 1; String instanceId = id.substring(startIndex); String textToClassify = lines[pos+2]; mWordsPlusCats.add(wordPlusCat); mInstanceIds.add(instanceId); mTextsToClassify.add(textToClassify); return pos + 2; } } static class SenseEvalModel extends HashMap> { static final long serialVersionUID = -6343177898894927184L; SenseEvalModel(SenseEvalDict dict, TrainingData trainingData) throws ClassNotFoundException, IOException { for (String wordPlusCat : trainingData.keySet()) { Map> senseToTextList = trainingData.get(wordPlusCat); String[] senseIds = senseToTextList.keySet().toArray(new String[0]); System.out.println(" " + wordPlusCat + " [" + senseIds.length + " senses]"); ClassificationHandler trainer = createClassifierTrainer(senseIds); for (String senseId : senseToTextList.keySet()) { Classification classificationForSenseId = new Classification(senseId); List trainingTextList = senseToTextList.get(senseId); for (String trainingText : trainingTextList) trainer.handle(trainingText, classificationForSenseId); } @SuppressWarnings("unchecked") Classifier classifier = (Classifier) AbstractExternalizable.compile((Compilable)trainer); put(wordPlusCat,classifier); } } } // wordPlusCat -> senseId -> { training-text } static class TrainingData extends HashMap>> { static final long serialVersionUID = 8094465899104433829L; public TrainingData(File file) throws IOException { String[] lines = FileLineReader.readLineArray(file,"ISO-8859-1"); for (int pos = 0; (pos = seek("= 0; ) pos = trainLexElt(lines,pos); } int trainLexElt(String[] lines, int pos) { String wordPlusCat = extractAttribute("item",lines[pos++]); while (pos < lines.length) { if (lines[pos].startsWith(" idSet = new HashSet(); for ( ; lines[pos].startsWith(" idSet) { for (String senseId : idSet) { if (senseId.equals(UNKNOWN_SENSE)) { continue; } Map> senseToTextListMap = get(wordPlusCat); if (senseToTextListMap == null) { senseToTextListMap = new HashMap>(); put(wordPlusCat,senseToTextListMap); } List trainingTextList = senseToTextListMap.get(senseId); if (trainingTextList == null) { trainingTextList = new ArrayList(); senseToTextListMap.put(senseId,trainingTextList); } trainingTextList.add(trainingText); } } } // wordpluscat -> sense static class SenseEvalDict extends HashMap { static final long serialVersionUID = -8332185573089002878L; SenseEvalDict(File file) throws IOException { String[] lines = FileLineReader.readLineArray(file,"ISO-8859-1"); for (int pos = 0; (pos = seek("= 0; ) pos = readDictionary(lines,pos); } int readDictionary(String[] lines, int pos) { String wordPlusCat = extractAttribute("item",lines[pos]); List senseList = new ArrayList(); while (lines[++pos].startsWith("toArray(new Sense[senseList.size()]); put(wordPlusCat,senses); return pos; } } static class Sense { String mId; String mSource; String mSynset; String mGloss; Sense(String line) { mId = extractAttribute("id",line); mSource = extractAttribute("source",line); mSynset = extractAttribute("synset",line); mGloss = extractAttribute("gloss",line); } public String toString() { return "ID=" + mId + " SRC=" + mSource + " SYNSET=" + mSynset + " GLOSS=" + mGloss; } } static String extractAttribute(String att, String line) { int start = line.indexOf(att + "=") + att.length()+2; int end = line.indexOf('"',start); return line.substring(start,end); } }