MLTea/LogTable

From CLSP Wiki

Addition in the log domain

It's common to store probabilities in the log domain to avoid underflow, especially when using structured models such as PCFGs and HMMs. Multiplication in the log domain is easy (it just becomes addition):

log(x * y) = log(x) + log(y)
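As a rough illustration of the underflow problem, the sketch below multiplies n copies of the same probability p directly and in the log domain. The function names directProduct and logProduct are made up for this example.

```cpp
#include <cmath>

// Multiply n copies of probability p directly.
// For small p and large n the result underflows to 0.
double directProduct(double p, int n)
{
   double prod = 1.0;
   for (int i = 0; i < n; i++)
      prod *= p;
   return prod;
}

// The same product in the log domain: each multiplication
// becomes an addition of log(p).
double logProduct(double p, int n)
{
   double logProd = 0.0; // log(1)
   for (int i = 0; i < n; i++)
      logProd += std::log(p);
   return logProd;
}
```

With p = 0.01 and n = 200 the true product is 1e-400, which is too small for a double, so directProduct returns 0. logProduct returns a finite value close to 200 * log(0.01), roughly -921.03.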

Addition in the log domain is a little more complex. The basic formula is:

log(x + y) = log(exp(a) + exp(b)), where a = log(x) and b = log(y)

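Evaluated directly, exp(a) and exp(b) can underflow to zero when a and b are large negative numbers, which defeats the purpose of working in the log domain. An improved formula avoids this to some extent by exponentiating only the difference b-a (it requires a >= b, so swap the two arguments if necessary):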

log(x + y) = a + log(1 + exp(b-a)), where a = log(x) and b = log(y)
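For readers who want the missing step, the improved formula can be derived by factoring out the larger term. In this sketch x and y are the two probabilities, with a = log(x), b = log(y) and a >= b; these symbol names are chosen here for the derivation.

```latex
\begin{aligned}
\log(x + y) &= \log\!\left(x\left(1 + \frac{y}{x}\right)\right) \\
            &= \log x + \log\!\left(1 + \frac{y}{x}\right) \\
            &= a + \log\!\left(1 + e^{\,b - a}\right)
\end{aligned}
```

Because b-a is at most 0, exp(b-a) always lies between 0 and 1, so it cannot overflow, and if it underflows the result is simply a.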

There is a specialized function for computing log(1+x) in C/C++ called "log1p()". It is more accurate than computing log(1+x) directly when x is close to zero, so it should be used whenever possible. Here is some code for log domain addition in C++:

#include <cmath> // log1p, exp

double logAdd(double a, double b)
{
   if (b > a) // Swap a and b if b is greater than a
   {
      double temp = a;
      a = b;
      b = temp;
   }
   return a + log1p(exp(b-a));
}
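A typical use of this function is summing a whole list of log probabilities, for example when marginalizing over states. The sketch below repeats logAdd so that it is self-contained and adds a helper called logSum, which is a made-up name for this example. It assumes the list is non-empty and that every entry is finite.

```cpp
#include <cmath>
#include <vector>

// Exact log-domain addition, same logic as the function above.
double logAdd(double a, double b)
{
   if (b > a) // Swap a and b if b is greater than a
   {
      double temp = a;
      a = b;
      b = temp;
   }
   return a + std::log1p(std::exp(b - a));
}

// Hypothetical helper: returns log(exp(v[0]) + exp(v[1]) + ...).
// Assumes logProbs is non-empty and all entries are finite.
double logSum(const std::vector<double>& logProbs)
{
   double total = logProbs[0];
   for (std::size_t i = 1; i < logProbs.size(); i++)
      total = logAdd(total, logProbs[i]);
   return total;
}
```

For two entries of -1000 this returns -1000 + log(2), about -999.307, whereas the basic formula gives minus infinity because exp(-1000) underflows to 0.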

Calls to log(), log1p() and exp() are relatively expensive, so if you are running something like EM on an HMM, they can account for a large percentage of your runtime. It's possible to precompute the result of "log1p(exp(b-a))" for many values of b-a, and perform a table lookup when adding two numbers in the log domain. Since b must be less than or equal to a, b-a is less than or equal to 0, and exp(b-a) quickly becomes very small as b-a gets smaller: at b-a = -64 it is about 1.6e-28, so the correction term can be treated as zero below that point. This gives us a reasonable way to set the range for the lookup table. Here is my C++ code for creating and accessing this lookup table to perform approximate addition in the log domain:

#include <cmath> // log1p, exp

#define LOG_ADD_TABLE_SIZE 60000 // Number of entries in the table
#define LOG_ADD_MIN -64.0 // Smallest value for b-a

class LogTable
{
private:
   const double logAddInc;
   const double invLogAddInc;
   double logAddTable[LOG_ADD_TABLE_SIZE+1];

public:
   LogTable() : logAddInc(-LOG_ADD_MIN/LOG_ADD_TABLE_SIZE),
      invLogAddInc(LOG_ADD_TABLE_SIZE / -LOG_ADD_MIN)
   {
      for(int i = 0; i <= LOG_ADD_TABLE_SIZE; i++)
         logAddTable[i] = log1p(exp((i * logAddInc) + LOG_ADD_MIN));
   }
       
   double logAdd(double a, double b) const
   {
      if (b > a) 
      {
         double temp = a;
         a = b;
         b = temp;
      }
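      // Offset of b-a from the bottom of the table's range (0 means b-a == LOG_ADD_MIN)
      double negDiff = (b - a) - LOG_ADD_MIN;
      if (negDiff < 0.0) // b is negligible compared to a
         return a;
      // Truncate to the table entry at or just below b-a
      return a + logAddTable[(int)(negDiff * invLogAddInc)];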
   }
};
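Since the lookup is approximate, it can be useful to check how far it strays from the exact formula. The sketch below is one way to do that. SmallLogTable is a compact stand-in written for this example, with the same table size and range as the class above, and maxTableError is a made-up helper that samples b-a across [-64, 0].

```cpp
#include <cmath>
#include <vector>

const int TABLE_SIZE = 60000;   // Same number of entries as above
const double TABLE_MIN = -64.0; // Same smallest value for b-a

// Compact stand-in for the LogTable class (illustration only).
class SmallLogTable
{
   std::vector<double> table;

public:
   SmallLogTable() : table(TABLE_SIZE + 1)
   {
      const double inc = -TABLE_MIN / TABLE_SIZE;
      for (int i = 0; i <= TABLE_SIZE; i++)
         table[i] = std::log1p(std::exp(i * inc + TABLE_MIN));
   }

   double logAdd(double a, double b) const
   {
      if (b > a) { double temp = a; a = b; b = temp; }
      double negDiff = (b - a) - TABLE_MIN;
      if (negDiff < 0.0)
         return a;
      return a + table[(int)(negDiff * (TABLE_SIZE / -TABLE_MIN))];
   }
};

// Largest absolute difference between the table lookup and the exact
// formula, sampled at evenly spaced values of b-a in [-64, 0].
double maxTableError(const SmallLogTable& t, int samples)
{
   double worst = 0.0;
   for (int k = 0; k <= samples; k++)
   {
      double d = TABLE_MIN * k / samples; // d = b-a, taking a = 0
      double exact = std::log1p(std::exp(d));
      double err = std::fabs(t.logAdd(0.0, d) - exact);
      if (err > worst)
         worst = err;
   }
   return worst;
}
```

The table step is 64/60000, about 0.00107, and log1p(exp(d)) has slope at most 0.5 for d <= 0, so the truncating lookup should be off by no more than about 5.3e-4 in the log value. Whether that is acceptable depends on the application; a larger table or interpolation between entries would reduce it.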

Below is a port of the above C++ code to Java:

 /**
  * A port of Jason Smith's C++ LogTable code.
  */
 public class LogAddTable {
   private static final int LOG_ADD_TABLE_SIZE = 60000; // Number of entries in the table
   private static final double LOG_ADD_MIN = -64; // Smallest value for b-a
   private static final double logAddInc = -LOG_ADD_MIN / LOG_ADD_TABLE_SIZE;
   private static final double invLogAddInc = LOG_ADD_TABLE_SIZE / -LOG_ADD_MIN;
   private static final double[] logAddTable = new double[LOG_ADD_TABLE_SIZE + 1];
   static {
     for (int i = 0; i <= LOG_ADD_TABLE_SIZE; i++) {
       logAddTable[i] = Math.log1p(Math.exp((i * logAddInc) + LOG_ADD_MIN));
     }
   }
   public static double logAdd(double a, double b) {
     if (b > a) {
       double temp = a;
       a = b;
       b = temp;
     }
     double negDiff = b - a - LOG_ADD_MIN;
     if (negDiff < 0.0) {
       return a;
     }
     return a + logAddTable[(int) (negDiff * invLogAddInc)];
   }
 }