Friday 24 February 2012

Recursing a constexpr

While reading the "Want speed? Use constexpr meta-programming!" blog post, I started thinking about how to reduce the recursion depth of such calculations. After all, using a compiler option (such as gcc's -fconstexpr-depth) to raise the recursion depth limit is not very friendly (gcc sets the limit at 512 by default). But before we get to reducing the depth of something like is_prime_func() from the aforementioned post, it's best to start with something slightly easier.

Consider a function to compute the following series:

$$\sum_{i=1}^{n} a\,i^2 = a\cdot 1^2 + a\cdot 2^2 + \cdots + a\cdot n^2$$

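For reference, and as a check on the test value used later in the post, this series has the familiar closed form:

$$\sum_{i=1}^{n} a\,i^2 = a\,\frac{n(n+1)(2n+1)}{6}, \qquad a = 3,\ n = 600:\quad 3\cdot\frac{600\cdot 601\cdot 1201}{6} = 216540300$$
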
Let's begin by writing it without constexpr, to be executed at runtime, but keeping in mind that we will later convert it for compile-time computation. We can write it as a recursive accumulate function to gain some generality:

template <typename Int, typename T, typename BinOp>
T accumulate(Int first, Int last, T init, BinOp op) {
    return (first >= last) ?
        init :
        accumulate(first+1, last, op(init, first), op);
}

int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, [a](int prev, int i) {
        return prev + a*i*i;
    });
}

accumulate's signature is very similar to that of std::accumulate, except that instead of taking iterators into a sequence, it operates directly on the half-open integer range [first, last). The sum_squares function calls accumulate() with a lambda expression that captures 'a'.

All is good except that accumulate() recurses once per element, so its recursion depth equals the length of the input sequence. We can reduce this to logarithmic depth by recursively splitting the sequence into two halves:

template <typename Int, typename T, typename BinOp>
T accumulate(Int first, Int last, T init, BinOp op) {
    return
        (first >= last ? init :
            (first+1 == last) ?
                op(init, first) :
                accumulate((first + last) / 2, last,
                    accumulate(first, (first + last) / 2, init, op), op));
}

At this point we should be able to simply prefix the function signatures with constexpr and get the magic of compile-time computation. Let's give it a try:


template <typename Int, typename T, typename BinOp>
constexpr T accumulate(Int first, Int last, T init, BinOp op) {
    return
        (first >= last ? init :
            (first+1 == last) ?
                op(init, first) :
                accumulate((first + last) / 2, last,
                    accumulate(first, (first + last) / 2, init, op), op));
}

constexpr int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, [a](int prev, int i) {
        return prev + a*i*i;
    });
}

int main() {
    static_assert(sum_squares(3, 600) == 216540300, "");
    return 0;
}

$ g++ -std=c++0x test.cpp
test.cpp: In function ‘int sum_squares(int, int)’:
test.cpp:17:1: error: ‘constexpr T accumulate(Int, Int, T, BinOp) [with Int = int, T = int, BinOp = sum_squares(int, int)::<lambda(int, int)>]’ is not ‘constexpr’
test.cpp: In function ‘int main()’:
test.cpp:21:2: error: non-constant condition for static assertion
test.cpp:21:34: error: ‘int sum_squares(int, int)’ is not a constexpr function

For some mysterious reason, the C++11 Standard explicitly forbids lambda expressions from appearing in constant expressions. So we have to fall back to good old functors:

struct add_sq {
    int a;
    constexpr int operator()(int p, int i) const { return p + a*i*i; }
};

constexpr int sum_squares(int a, int n) {
    return accumulate(1, n+1, 0, add_sq{ a });
}

Generalization to for-loops

It is interesting to note that our accumulate() function can work for arbitrary for-loops of the form:
for( int i = start; i < end; i++ )
    body;
To see that, observe that a loop starts with some initial state and then mutates it on every iteration in the body of the loop. Functionally, we can view this mutation as applying a binary operation to the old state and the iteration index to produce the new state. This is precisely what our accumulate function does. The divide-and-conquer version still applies op to the indices in order from first to last, because the result of the left half becomes the initial state of the right half, so the iteration order of the loop is preserved. Of course, the state might encompass more than just a scalar, in which case we need a composite type to hold the individual elements. It would be great if std::tuple could be used with constexpr, but since it can't, it's easy enough to define a struct to hold the necessary state.


Conclusion

As we have seen, it is straightforward to convert a recursive function with linear depth into one with logarithmic depth if the number of recursions is known at the point of its invocation. This makes it easy to convert a classic for-loop into recursive form with manageable depth. Next time we'll look at how to convert a while-loop, where the number of iterations is not initially known, into logarithmic depth recursion and in the process define the promised is_prime(n) capable of dealing with very large numbers.

2 comments:

  1. Reducing recursion depth from linear to logarithmic is an interesting idea for constexpr functions. I also liked the similarities between a loop and accumulate.

    However, I do not agree with your claim that an arbitrary for loop can be turned into a sequence of calls to accumulate, at least not in the form you have shown. For instance, consider the break and continue keywords. I don't think it is easy to simulate them using accumulate. Exceptions could be used. Also, when the state is non-scalar (e.g., std::array), accumulate may be much less efficient than a regular for loop because T is returned by value.

    Also, later in the article you seem to mix run-time computations that for loops usually do with constexpr, which is obviously not possible.

  2. @Sumant: You are right, it generalizes only to certain types of for-loops (determinate loops), where the number of iterations is known a priori (start..end). This excludes 'break' statements. My next post discusses while-loops, which allow for this type of behavior. A 'continue' statement just skips part of the body and can be implemented as a guarded expression in the BinOp.

    The body of the loop must obviously be possible to write in functional form (more precisely as a constexpr) so that does rule out a lot of code (e.g. using std::cout or having a goto). I was concentrating on "algorithmic" code that can usually be written in functional style and useful in meta-programming.
