Friday 24 February 2012

Recursing a constexpr

While reading the "Want speed? Use constexpr meta-programming!" blog post, I started thinking about how to reduce the recursion depth of such calculations. After all, using a compiler option to raise the recursion depth limit is not very friendly (gcc sets the limit at 512 by default). But before we get to reducing the depth of something like is_prime_func() from the aforementioned post, it's best to start with something slightly easier.

Consider a function to compute the following series:

$$S(a, n) = \sum_{i=1}^{n} a\,i^2 = a \cdot 1^2 + a \cdot 2^2 + \dots + a \cdot n^2$$


Let's begin by writing it without constexpr, to be executed at runtime, but keeping in mind that we will later convert it for compile-time computation. We can write it as a recursive accumulate function to gain some generality:

template <typename Int, typename T, typename BinOp>
T accumulate(Int first, Int last, T init, BinOp op) {
    return (first >= last) ?
        init :
        accumulate(first+1, last, op(init, first), op);
}

int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, [a](int prev, int i) {
        return prev + a*i*i;
    });
}

accumulate's signature is very similar to that of std::accumulate, except that instead of dealing with iterators into a sequence, it operates directly on the half-open integer range [first, last). The sum_squares function calls accumulate() with a lambda expression that captures 'a'.

All is good except that accumulate() will produce a recursion depth equal to the length of the input sequence. We can reduce this to logarithmic depth by recursively splitting the sequence into two halves:

template <typename Int, typename T, typename BinOp>
T accumulate(Int first, Int last, T init, BinOp op) {
    return
        (first >= last ? init :
            (first+1 == last) ?
                op(init, first) :
                accumulate((first + last) / 2, last,
                    accumulate(first, (first + last) / 2, init, op), op));
}
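
To make the difference concrete, here is a rough sketch that counts the deepest chain of nested calls for each strategy. The helpers linear_depth, halving_depth and max_int are hypothetical names introduced for illustration; they mirror the call structure of the two accumulate() versions above but are not part of the original code.

```cpp
// Hypothetical helpers (not from the original post) that count nesting depth
// instead of accumulating a value. Each is a single return statement, so they
// are also valid C++11 constexpr functions.

constexpr int max_int(int a, int b) { return a > b ? a : b; }

// Linear version: every element adds one nested call, plus the final call
// that sees first >= last and returns init.
constexpr int linear_depth(int first, int last) {
    return (first >= last) ? 1 : 1 + linear_depth(first + 1, last);
}

// Halving version: the two recursive calls run one after the other, so the
// depth is one more than the deeper of the two halves.
constexpr int halving_depth(int first, int last) {
    return (first >= last || first + 1 == last) ? 1 :
        1 + max_int(halving_depth(first, (first + last) / 2),
                    halving_depth((first + last) / 2, last));
}

// 100 elements: 101 nested calls for the linear version.
static_assert(linear_depth(1, 101) == 101, "");
// 600 elements: only 11 nested calls for the halving version.
static_assert(halving_depth(1, 601) == 11, "");
```

For a 600-element range the linear version would need about 601 nested calls, which is past the 512 default mentioned earlier, while the halving version needs 11. Note also that the halving version still feeds the elements to op in increasing order, with the result of the left half becoming the init of the right half, so op does not have to be associative.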

At this point we should be able to just prefix the function signatures with constexpr and get the magic of compile-time computation. Let's give it a try:


template <typename Int, typename T, typename BinOp>
constexpr T accumulate(Int first, Int last, T init, BinOp op) {
    return
        (first >= last ? init :
            (first+1 == last) ?
                op(init, first) :
                accumulate((first + last) / 2, last,
                    accumulate(first, (first + last) / 2, init, op), op));
}

constexpr int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, [a](int prev, int i) {
        return prev + a*i*i;
    });
}

int main() {
    static_assert(sum_squares(3, 600) == 216540300, "");
    return 0;
}

$ g++ -std=c++0x test.cpp
test.cpp: In function ‘int sum_squares(int, int)’:
test.cpp:17:1: error: ‘constexpr T accumulate(Int, Int, T, BinOp) [with Int = int, T = int, BinOp = sum_squares(int, int)::<lambda(int, int)>]’ is not ‘constexpr’
test.cpp: In function ‘int main()’:
test.cpp:21:2: error: non-constant condition for static assertion
test.cpp:21:34: error: ‘int sum_squares(int, int)’ is not a constexpr function

For some mysterious reason, the C++11 standard explicitly forbids lambda expressions from appearing in constant expressions. So we have to fall back on good old functors:

struct add_sq {
    int a;
    constexpr int operator()(int p, int i) const { return p + a*i*i; }
};

constexpr int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, add_sq{ a });
}
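
As a sanity check, the pieces can be put together into one translation unit and compared against the textbook closed form for a sum of squares. This is only a sketch: closed_form is a hypothetical helper added here for the cross-check and is not part of the original code; accumulate(), add_sq and sum_squares() are repeated from the text so that the block compiles on its own.

```cpp
// The halving accumulate() from the text.
template <typename Int, typename T, typename BinOp>
constexpr T accumulate(Int first, Int last, T init, BinOp op) {
    return
        (first >= last ? init :
            (first+1 == last) ?
                op(init, first) :
                accumulate((first + last) / 2, last,
                    accumulate(first, (first + last) / 2, init, op), op));
}

// The functor that replaces the lambda.
struct add_sq {
    int a;
    constexpr int operator()(int p, int i) const { return p + a*i*i; }
};

constexpr int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, add_sq{ a });
}

// Hypothetical cross-check (an addition, not from the original post):
// a * (1^2 + 2^2 + ... + n^2) = a * n * (n+1) * (2n+1) / 6.
constexpr long long closed_form(long long a, long long n) {
    return a * n * (n + 1) * (2 * n + 1) / 6;
}

// Both checks are evaluated entirely at compile time.
static_assert(sum_squares(3, 600) == 216540300, "");
static_assert(sum_squares(3, 600) == closed_form(3, 600), "");
```

If either static_assert failed, or if the recursion went past the compiler's depth limit, the file would simply not compile, so a successful build is the test.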

Generalization to for-loop

It is interesting to note that our accumulate() function can express arbitrary for-loops of the form:

for( int i = start; i < end; i++ )
    body;

To see that, observe that a loop starts with some initial state and then mutates it on every iteration in the body of the loop. Functionally, we can view this mutation as applying an operation of two arguments, the old state and the iteration index, to produce the new state. This is precisely what our accumulate() function does. Of course, the state might encompass more than just a scalar, in which case we need a composite type to hold the individual elements. It would be great if std::tuple could work with constexpr, but since it doesn't, it's easy enough to define a struct to hold the necessary state.
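
As an illustration of a loop whose state is more than a scalar, here is a minimal sketch that computes Fibonacci numbers. The names fib_state, fib_step and fib are hypothetical and chosen only for this example; accumulate() is the halving version from earlier in the post, repeated so that the block compiles on its own.

```cpp
// The halving accumulate() from the text.
template <typename Int, typename T, typename BinOp>
constexpr T accumulate(Int first, Int last, T init, BinOp op) {
    return
        (first >= last ? init :
            (first+1 == last) ?
                op(init, first) :
                accumulate((first + last) / 2, last,
                    accumulate(first, (first + last) / 2, init, op), op));
}

// Hypothetical example (not from the original post). The loop
//     long long a = 0, b = 1;
//     for( int i = 0; i < n; i++ ) { long long t = a + b; a = b; b = t; }
// carries two variables, so the state is a small struct.
struct fib_state {
    long long a;  // F(i)
    long long b;  // F(i+1)
};

// One loop iteration: map the old state and the (unused) index to the new state.
struct fib_step {
    constexpr fib_state operator()(fib_state s, int) const {
        return fib_state{ s.b, s.a + s.b };
    }
};

// After n iterations starting from (F(0), F(1)) = (0, 1), member a holds F(n).
constexpr long long fib(int n) {
    return accumulate(0, n, fib_state{ 0, 1 }, fib_step{}).a;
}

static_assert(fib(0) == 0, "");
static_assert(fib(10) == 55, "");
```

The loop body became fib_step::operator(), the loop variables became the members of fib_state, and the loop bounds became the first two arguments of accumulate().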


Conclusion

As we have seen, it is straightforward to convert a recursive function with linear depth into one with logarithmic depth if the number of recursions is known at the point of its invocation. This makes it easy to convert a classic for-loop into recursive form with manageable depth. Next time we'll look at how to convert a while-loop, where the number of iterations is not initially known, into logarithmic depth recursion and in the process define the promised is_prime(n) capable of dealing with very large numbers.