/*
 * (C) Copyright 2009 Nuxeo SA (http://nuxeo.com/) and others.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Contributors:
 * Olivier Grisel
 */

package org.nuxeo.ecm.platform.categorization.categorizer.tfidf;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.nuxeo.ecm.platform.categorization.service.Categorizer;

/**
 * Maintains a map of TF counts vectors in memory (just for a few reference documents or topics) along with the common
 * IDF estimate of all previously seen text content.
 * <p>
 * See: http://en.wikipedia.org/wiki/Tfidf
 * <p>
 * Classification is then achieved using the cosine similarity between the TF-IDF of the document to classify and the
 * registered topics.
 */
public class TfIdfCategorizer extends PrimitiveVectorHelper implements Categorizer, Serializable {

    private static final long serialVersionUID = 1L;

    public static final Log log = LogFactory.getLog(TfIdfCategorizer.class);

    /** Names of the registered topics, kept sorted for deterministic iteration. */
    protected final Set<String> topicNames = new TreeSet<String>();

    /** Per-topic raw term count vectors (long[] of size {@link #dim}). */
    protected final Map<String, Object> topicTermCount = new ConcurrentHashMap<String, Object>();

    /** Per-topic cached TF-IDF vectors (float[] of size {@link #dim}). */
    protected final Map<String, Object> cachedTopicTfIdf = new ConcurrentHashMap<String, Object>();

    /** Per-topic cached Euclidean norm of the TF-IDF vector. */
    protected final Map<String, Float> cachedTopicTfIdfNorm = new ConcurrentHashMap<String, Float>();

    /** Aggregated term counts over all topics, used for the IDF estimate; null after {@link #disableUpdate()}. */
    protected long[] allTermCounts;

    /** Dimension of the hashed vector space. */
    protected final int dim;

    /** Cached IDF vector, invalidated on every update. */
    protected float[] cachedIdf;

    protected long totalTermCount = 0;

    protected final HashingVectorizer vectorizer;

    // transient: Lucene analyzers are not serializable, rebuilt lazily after deserialization
    protected transient Analyzer analyzer;

    /** Default similarity-over-median ratio threshold used by {@link #guessCategories(String, int)}. */
    protected Double ratioOverMedian = 3.0;

    protected boolean updateDisabled = false;

    public TfIdfCategorizer() {
        this(524288); // 2 ** 19
    }

    public TfIdfCategorizer(int dim) {
        this.dim = dim;
        allTermCounts = new long[dim];
        vectorizer = new HashingVectorizer().dimension(dim);
    }

    public HashingVectorizer getVectorizer() {
        return vectorizer;
    }

    public Analyzer getAnalyzer() {
        if (analyzer == null) {
            // TODO: make it possible to configure the stop words
            analyzer = new StandardAnalyzer(Version.LUCENE_47);
        }
        return analyzer;
    }

    /**
     * Precompute all the TF-IDF vectors and unload the original count vectors to spare some memory. Updates won't be
     * possible any more.
     */
    public synchronized void disableUpdate() {
        updateDisabled = true;
        // compute all the frequencies
        getIdf();
        for (String topicName : topicNames) {
            tfidf(topicName);
            tfidfNorm(topicName);
        }
        // unload the count vectors: only the cached TF-IDF representations are kept
        topicTermCount.clear();
        allTermCounts = null;
    }

    /**
     * Update the model to take into account the statistical properties of a document that is known to be relevant to
     * the given topic. Warning: this method is not thread safe: it should not be used concurrently with @see
     * #getSimilarities(List)
     *
     * @param topicName the name of the document topic or category
     * @param terms the list of document tokens (use a lucene analyzer to extract them for instance)
     */
    public void update(String topicName, List<String> terms) {
        if (updateDisabled) {
            throw new IllegalStateException("updates are no longer authorized once #disableUpdate has been called");
        }
        long[] counts = vectorizer.count(terms);
        totalTermCount += sum(counts);
        long[] topicCounts = (long[]) topicTermCount.get(topicName);
        if (topicCounts == null) {
            topicCounts = new long[dim];
            topicTermCount.put(topicName, topicCounts);
            topicNames.add(topicName);
        }
        add(topicCounts, counts);
        add(allTermCounts, counts);
        invalidateCache(topicName);
    }

    /**
     * Update the model to take into account the statistical properties of a document that is known to be relevant to
     * the given topic. Warning: this method is not thread safe: it should not be used concurrently with @see
     * #getSimilarities(List)
     *
     * @param topicName the name of the document topic or category
     * @param textContent textual content to be tokenized and analyzed
     */
    public void update(String topicName, String textContent) {
        update(topicName, tokenize(textContent));
    }

    protected void invalidateCache(String topicName) {
        cachedTopicTfIdf.remove(topicName);
        cachedTopicTfIdfNorm.remove(topicName);
        cachedIdf = null;
    }

    protected void invalidateCache() {
        for (String topicName : topicNames) {
            invalidateCache(topicName);
        }
    }

    /**
     * For each registered topic, compute the cosine similarity of the TFIDF vector of the topic and the one of the
     * document given by a list of tokens.
     *
     * @param terms a tokenized document.
     * @return a map of topic names to float values from 0 to 1 sorted by reverse value.
     */
    public Map<String, Float> getSimilarities(List<String> terms) {
        SortedMap<String, Float> similarities = new TreeMap<String, Float>();

        float[] tfidf1 = getTfIdf(vectorizer.count(terms));
        float norm1 = normOf(tfidf1);
        if (norm1 == 0) {
            // empty or fully-unknown document: no meaningful similarity can be computed
            return similarities;
        }

        for (String topicName : topicNames) {
            float[] tfidf2 = tfidf(topicName);
            float norm2 = tfidfNorm(topicName);
            if (norm2 == 0) {
                continue;
            }
            similarities.put(topicName, dot(tfidf1, tfidf2) / (norm1 * norm2));
        }
        return sortByDecreasingValue(similarities);
    }

    /**
     * For each registered topic, compute the cosine similarity of the TFIDF vector of the topic and the one of the
     * document.
     *
     * @param textContent the document to be tokenized and analyzed
     * @return a map of topic names to float values from 0 to 1 sorted by reverse value.
     */
    public Map<String, Float> getSimilarities(String textContent) {
        return getSimilarities(tokenize(textContent));
    }

    protected float tfidfNorm(String topicName) {
        Float norm = cachedTopicTfIdfNorm.get(topicName);
        if (norm == null) {
            norm = normOf(tfidf(topicName));
            cachedTopicTfIdfNorm.put(topicName, norm);
        }
        return norm.floatValue();
    }

    protected float[] tfidf(String topicName) {
        float[] tfidf = (float[]) cachedTopicTfIdf.get(topicName);
        if (tfidf == null) {
            tfidf = getTfIdf((long[]) topicTermCount.get(topicName));
            cachedTopicTfIdf.put(topicName, tfidf);
        }
        return tfidf;
    }

    /**
     * Compute the TF-IDF vector for the given raw term counts using the current global IDF estimate.
     *
     * @param counts raw term counts (hashed vector of size {@link #dim})
     * @return the TF-IDF vector; all zeros when {@code counts} sums to zero
     */
    protected float[] getTfIdf(long[] counts) {
        float[] idf = getIdf();
        float[] tfidf = new float[counts.length];
        long sum = sum(counts);
        if (sum == 0) {
            return tfidf;
        }
        for (int i = 0; i < counts.length; i++) {
            tfidf[i] = ((float) counts[i]) / sum * idf[i];
        }
        return tfidf;
    }

    protected float[] getIdf() {
        if (cachedIdf == null) {
            float[] idf = new float[allTermCounts.length];
            for (int i = 0; i < allTermCounts.length; i++) {
                if (allTermCounts[i] == 0) {
                    // never-seen term: zero weight rather than an infinite IDF
                    idf[i] = 0;
                } else {
                    idf[i] = (float) Math.log1p(((float) totalTermCount) / allTermCounts[i]);
                }
            }
            // atomic update to ensure thread-safeness
            cachedIdf = idf;
        }
        return cachedIdf;
    }

    public int getDimension() {
        return dim;
    }

    /**
     * Utility method to initialize the parameters from a set of UTF-8 encoded text files with names used as topic
     * names.
     * <p>
     * The content of the file is assumed to be lines of terms separated by whitespaces without punctuation.
     *
     * @param folder
     */
    public void learnFiles(File folder) throws IOException {
        if (!folder.isDirectory()) {
            throw new IOException(String.format("%s is not a folder", folder.getAbsolutePath()));
        }
        File[] files = folder.listFiles();
        if (files == null) {
            // listFiles() can return null on an I/O error even after the isDirectory check
            throw new IOException(String.format("could not list content of folder %s", folder.getAbsolutePath()));
        }
        for (File file : files) {
            if (file.isDirectory()) {
                continue;
            }
            // strip the file extension to build the topic name
            String topicName = file.getName();
            if (topicName.contains(".")) {
                topicName = topicName.substring(0, topicName.indexOf('.'));
            }
            log.info(String.format("About to analyze file %s", file.getAbsolutePath()));
            FileInputStream is = new FileInputStream(file);
            try {
                BufferedReader reader = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")));
                String line = reader.readLine();
                int i = 0;
                while (line != null) {
                    update(topicName, line);
                    line = reader.readLine();
                    i++;
                    if (i % 10000 == 0) {
                        log.info(String.format("Analyzed %d lines from '%s'", i, file.getAbsolutePath()));
                    }
                }
            } finally {
                is.close();
            }
        }
    }

    /**
     * Save the model to a compressed binary format on the filesystem.
     *
     * @param file where to write the model
     */
    public void saveToFile(File file) throws IOException {
        FileOutputStream out = new FileOutputStream(file);
        try {
            saveToStream(out);
        } finally {
            out.close();
        }
    }

    /**
     * Save a compressed binary representation of the trained model.
     *
     * @param out the output stream to write to
     */
    public void saveToStream(OutputStream out) throws IOException {
        if (updateDisabled) {
            throw new IllegalStateException("model in disabled update mode cannot be saved");
        }
        invalidateCache();
        GZIPOutputStream gzOut = new GZIPOutputStream(out);
        ObjectOutputStream objOut = new ObjectOutputStream(gzOut);
        objOut.writeObject(this);
        // flush the object stream's internal buffer before finishing the gzip trailer,
        // otherwise buffered serialization data may be lost
        objOut.flush();
        gzOut.finish();
    }

    /**
     * Load a TfIdfCategorizer instance from its compressed binary representation.
     *
     * @param in the input stream to read from
     * @return a new instance with parameters coming from the saved version
     */
    public static TfIdfCategorizer load(InputStream in) throws IOException, ClassNotFoundException {
        GZIPInputStream gzIn = new GZIPInputStream(in);
        ObjectInputStream objIn = new ObjectInputStream(gzIn);
        TfIdfCategorizer cat = (TfIdfCategorizer) objIn.readObject();
        log.info(String.format("Successfully loaded model with %d topics, dimension %d and density %f",
                cat.getTopicNames().size(), cat.getDimension(), cat.getDensity()));
        return cat;
    }

    /**
     * @return the fraction of non-zero components over all count vectors (topics plus global counts).
     */
    public double getDensity() {
        long sum = 0;
        for (Object singleTopicTermCount : topicTermCount.values()) {
            for (long c : (long[]) singleTopicTermCount) {
                sum += c != 0L ? 1 : 0;
            }
        }
        for (long c : allTermCounts) {
            sum += c != 0 ? 1 : 0;
        }
        return ((double) sum) / ((topicNames.size() + 1) * getDimension());
    }

    public Set<String> getTopicNames() {
        return topicNames;
    }

    /**
     * Load a TfIdfCategorizer instance from its compressed binary representation from a named resource in the
     * classloading path of the current thread.
     *
     * @param modelPath the path of the file model in the classloading path
     * @return a new instance with parameters coming from the saved version
     */
    public static TfIdfCategorizer load(String modelPath) throws IOException, ClassNotFoundException {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        return load(loader.getResourceAsStream(modelPath));
    }

    public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {
        if (args.length < 2 || args.length > 3) {
            System.out.println("Train a model:\n" + "First argument is the model filename (e.g. my-model.gz)\n"
                    + "Second argument is the path to a folder with UTF-8 text files\n"
                    + "Third optional argument is the dimension of the model");
            System.exit(0);
        }
        File modelFile = new File(args[0]);
        TfIdfCategorizer categorizer;
        if (modelFile.exists()) {
            log.info("Loading model from: " + modelFile.getAbsolutePath());
            FileInputStream is = new FileInputStream(modelFile);
            try {
                categorizer = load(is);
            } finally {
                is.close();
            }
        } else {
            if (args.length == 3) {
                categorizer = new TfIdfCategorizer(Integer.valueOf(args[2]));
            } else {
                categorizer = new TfIdfCategorizer();
            }
            log.info("Initializing new model with dimension: " + categorizer.getDimension());
        }
        categorizer.learnFiles(new File(args[1]));
        log.info("Saving trained model to: " + modelFile.getAbsolutePath());
        categorizer.saveToFile(modelFile);
    }

    public List<String> guessCategories(String textContent, int maxSuggestions) {
        return guessCategories(textContent, maxSuggestions, null);
    }

    /**
     * Suggest up to {@code maxSuggestions} topic names for the given text, keeping only topics whose similarity is at
     * least {@code precisionThreshold} times the median similarity.
     *
     * @param textContent the document to categorize
     * @param maxSuggestions maximum number of topic names to return
     * @param precisionThreshold ratio-over-median cutoff; defaults to {@link #ratioOverMedian} when null
     * @return topic names sorted by decreasing similarity
     */
    public List<String> guessCategories(String textContent, int maxSuggestions, Double precisionThreshold) {
        precisionThreshold = precisionThreshold == null ? ratioOverMedian : precisionThreshold;
        Map<String, Float> sims = getSimilarities(tokenize(textContent));
        Float median = findMedian(sims);
        List<String> suggested = new ArrayList<String>();
        for (Map.Entry<String, Float> sim : sims.entrySet()) {
            // median == 0 means most topics have no overlap: accept any positive similarity
            double ratio = median != 0 ? sim.getValue() / median : 100;
            if (suggested.size() >= maxSuggestions || ratio < precisionThreshold) {
                break;
            }
            suggested.add(sim.getKey());
        }
        return suggested;
    }

    /**
     * Tokenize the text content with the configured Lucene analyzer.
     *
     * @param textContent the raw text to tokenize
     * @return the list of analyzed terms
     */
    public List<String> tokenize(String textContent) {
        try {
            List<String> terms = new ArrayList<String>();
            TokenStream tokenStream = getAnalyzer().tokenStream(null, textContent);
            try {
                CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    terms.add(charTermAttribute.toString());
                }
                tokenStream.end();
            } finally {
                // always release the token stream, even if reset/incrementToken throws
                tokenStream.close();
            }
            return terms;
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Return a copy of the map with entries ordered by decreasing value.
     */
    public static Map<String, Float> sortByDecreasingValue(Map<String, Float> map) {
        List<Map.Entry<String, Float>> list = new LinkedList<Map.Entry<String, Float>>(map.entrySet());
        Collections.sort(list, new Comparator<Map.Entry<String, Float>>() {
            public int compare(Map.Entry<String, Float> e1, Map.Entry<String, Float> e2) {
                return -e1.getValue().compareTo(e2.getValue());
            }
        });
        Map<String, Float> result = new LinkedHashMap<String, Float>();
        for (Map.Entry<String, Float> e : list) {
            result.put(e.getKey(), e.getValue());
        }
        return result;
    }

    /**
     * Return the middle value of a value-sorted map (0.0 for an empty map).
     */
    public static Float findMedian(Map<String, Float> sortedMap) {
        int remaining = sortedMap.size() / 2;
        Float median = 0.0f;
        for (Float value : sortedMap.values()) {
            median = value;
            if (remaining-- <= 0) {
                break;
            }
        }
        return median;
    }

}