Metaprogramming
Metaprogramming is the technique of writing a computer program that operates on other programs. Systems such as compilers and program analyzers can be considered metaprograms, since they take other programs as input. The forms of metaprogramming we will discuss here are specifically concerned with generating code to be included as part of a program. In a sense, they can be considered rudimentary compilers.
Macros and Code Generation
A macro is a rule that translates an input sequence into some replacement output sequence. This translation process is called macro expansion, and some languages provide macros as part of their specification. The macro facility may be implemented as a preprocessing step, where macro expansion occurs before lexical and syntactic analysis, or it may be incorporated as part of syntax analysis or a later translation step.
One of the most widely used macro systems is the C preprocessor (CPP),
which is included in both C and C++ as the first step in processing
a program. Preprocessor directives begin with a hash symbol and
include #include
, #define
, #if
, among others. For instance,
the following defines a function-like macro to swap two items:
#define SWAP(a, b) { auto tmp = b; b = a; a = tmp; }
We can then use the macro as follows:
int main() {
    int x = 3;
    int y = 4;
    SWAP(x, y);
    cout << x << " " << y << endl;
}
Running the resulting executable will print a 4, followed by a 3.
The results of macro expansion can be obtained by passing the -E
flag to g++
:
$ g++ -E <source>
However, the results can be quite messy if there are #include
directives, since each one pulls in the code from the given file.
CPP macros perform text replacement, so that the code above is equivalent to:
int main() {
    int x = 3;
    int y = 4;
    { auto tmp = y; y = x; x = tmp; };
    cout << x << " " << y << endl;
}
The semicolon following the use of the SWAP
macro remains,
denoting an empty statement. This is a problem, however, in contexts
that require a single statement, such as a conditional branch that is
not enclosed by a block:
if (x < y)
    SWAP(x, y);
else
    cout << "no swap" << endl;
A common idiom to avoid this problem is to place the expansion code for the macro inside of a do/while:
#define SWAP(a, b) do { \
    auto tmp = b;        \
    b = a;               \
    a = tmp;             \
} while (false)
Here, we’ve placed a backslash at the end of a line to denote that the
next line should be considered a continuation of the previous one. A
do/while loop syntactically ends with a semicolon, so that the
semicolon in SWAP(x, y);
is syntactically part of the do/while
loop. Thus, the expanded code has the correct syntax:
if (x < y)
    do { auto tmp = y; y = x; x = tmp; } while (false);
else
    cout << "no swap" << endl;
While textual replacement is useful, it does have drawbacks, stemming from the fact that though the macros are syntactically function like, they do not behave as functions. Specifically, they do not treat arguments as their own entities, and they do not introduce a separate scope. Consider the following example:
int main() {
    int x = 3;
    int y = 4;
    int z = 5;
    SWAP(x < y ? x : y, z);
    cout << x << " " << y << " " << z << endl;
}
Running the resulting program produces the unexpected result:
3 4 3
Using g++ -E
, we can see what the preprocessed code looks like.
Looking only at the output for main()
, we find:
int main() {
    int x = 3;
    int y = 4;
    int z = 5;
    do {
        auto tmp = z;
        z = x < y ? x : y;
        x < y ? x : y = tmp;
    } while (false);
    cout << x << " " << y << " " << z << endl;
}
Here, we’ve manually added line breaks and whitespace to make the output more readable; the preprocessor itself places the macro output on a single line. The culprit is the last generated statement:
x < y ? x : y = tmp;
In C++, the conditional operator ? :
and the assignment operator
=
have the same precedence and associate right to left, so this is
equivalent to:
x < y ? x : (y = tmp);
Since x < y
, no assignment happens here. Thus, the value of x
is unchanged.
We can fix this problem by placing parentheses around each use of a macro argument:
#define SWAP(a, b) do { \
    auto tmp = (b);      \
    (b) = (a);           \
    (a) = tmp;           \
} while (false)
This now produces the expected result, as the operations are explicitly associated by the parentheses:
int main() {
    int x = 3;
    int y = 4;
    int z = 5;
    do {
        auto tmp = (z);
        (z) = (x < y ? x : y);
        (x < y ? x : y) = tmp;
    } while (false);
    cout << x << " " << y << " " << z << endl;
}
A second problem, however, is not as immediately fixable. Consider
what happens when we apply the SWAP
macro to a variable named
tmp
:
int main() {
    int x = 3;
    int tmp = 4;
    SWAP(tmp, x);
    cout << x << " " << tmp << endl;
}
Running this code results in:
3 4
No swap occurs! Again, using g++ -E
to examine the output, we see
(modulo spacing):
int main() {
    int x = 3;
    int tmp = 4;
    do {
        auto tmp = (x);
        (x) = (tmp);
        (tmp) = tmp;
    } while (false);
    cout << x << " " << tmp << endl;
}
Since the temporary variable used by SWAP
has the same name as an
argument, the temporary captures the occurrences of the argument in
the generated code. This is because the macro merely performs text
substitution, which does not ensure that names get resolved to the
appropriate scope. (Thus, macros do not actually use call by name,
which does ensure that a name in an argument resolves to the
appropriate scope.) The reliance on text replacement makes CPP a
non-hygienic macro system. Other systems, such as Scheme’s, are
hygienic, creating a separate scope for names introduced by a macro
and ensuring that arguments are not captured.
Scheme Macros
The macro system defined as part of the R5RS Scheme specification is
hygienic. A macro is introduced by one of the define-syntax
,
let-syntax
, or letrec-syntax
forms, and it binds the given
name to the macro. As an example, the following is a definition of
let
as a macro:
(define-syntax let
  (syntax-rules ()
    ((let ((name val) ...)
      body1 body2 ...
     )
     ((lambda (name ...)
        body1 body2 ...
      )
      val ...
     )
    )
  )
)
The syntax-rules
form specifies the rules for the macro
transformation. The first argument is a list of literals that must
match between the pattern of the rule and the input. An example is the
else
identifier inside of a cond
form. In this case, however,
there are no literals. The remaining arguments to syntax-rules
specify transformations. The first item in a transformation is the
input pattern, and the second is the output pattern. The ...
acts
like a Kleene star, matching the previous item to zero or more
occurrences in the input. The names that appear in an input pattern
but are not in the list of literals, except for the first item (the
macro name), are pattern variables that match input elements. The
variables can then be referenced in the output pattern to specify how
to construct the output.
Evaluating the expression above in the global environment binds the
name let
to a macro that translates to a lambda
.
Identifiers introduced by the body of a macro are guaranteed to avoid
conflict with other identifiers, and the interpreter often renames
identifiers to avoid such a conflict. Consider the following
definition of a swap
macro:
(define-syntax swap
  (syntax-rules ()
    ((swap a b)
     (let ((tmp b))
       (set! b a)
       (set! a tmp)
     )
    )
  )
)
This translates a use of swap
to an expression that swaps the
two arguments through a temporary variable tmp
. Thus:
> (define x 3)
> (define y 4)
> (swap x y)
> x
4
> y
3
However, unlike CPP macros, the tmp
introduced by the swap
macro is distinct from any other tmp
:
> (define tmp 5)
> (swap x tmp)
> x
5
> tmp
4
Because macros are hygienic in Scheme, we get the expected behavior.
In order to support macros, the evaluation procedure of the Scheme
interpreter evaluates the first item in a list, as usual. If it
evaluates to a macro, then the interpreter performs macro expansion on
the rest of the list without first evaluating the arguments. Any names
introduced by the expansion are placed in a separate scope from other
names. After expansion, the interpreter repeats the evaluation process
on the result of expansion, so that if the end result is a let
expression as in swap
above, the expression is evaluated.
A macro definition can specify multiple pattern rules. Combined with
the fact that the result of expansion is evaluated, this allows a
macro to be recursive, as in the following definition of let*
:
(define-syntax let*
  (syntax-rules ()
    ((let* ()
      body1 body2 ...
     )
     (let ()
       body1 body2 ...
     )
    )
    ((let* ((name1 val1) (name2 val2) ...)
      body1 body2 ...
     )
     (let ((name1 val1))
       (let* ((name2 val2) ...)
         body1 body2 ...
       )
     )
    )
  )
)
There is a base-case pattern for when the let*
has no bindings, in
which case it translates directly into a let
. There is also a
recursive pattern for when there is at least one binding, in which
case the let*
translates into a simpler let*
nested within a
let
. The ellipsis (...
) in a macro definition is similar to a
Kleene star (*
) in a regular expression, denoting that the
preceding item can be matched zero or more times. Thus, a let*
with a single binding matches the second pattern rule above, where
(name2 val2)
is matched zero times.
CPP Macros
We return our attention to CPP macros. Despite their non-hygienic nature, they can be very useful in tasks that involve metaprogramming.
CPP allows us to use #define
to define two types of macros,
object-like and function-like macros. An object-like macro is a
simple text replacement, substituting one sequence of text for
another. Historically, a common use was to define constants:
#define PI 3.1415926535
int main() {
    cout << "pi = " << PI << endl;
    cout << "tau = " << PI * 2 << endl;
}
Better practice in C++ is to define a constant using const
or
constexpr
.
A function-like macro takes arguments, as in SWAP
above, and can
substitute the argument text into specific locations within the
replacement text.
A more complex example of using function-like macros is to abstract the definition of multiple pieces of code that follow the same pattern. Consider the definition of a type to represent a complex number:
struct Complex {
    double real;
    double imag;
};

ostream &operator<<(ostream &os, Complex c) {
    return os << "(" << c.real << "+" << c.imag << "i)";
}
Suppose that in addition to the overloaded stream insertion operator
above, we wish to support the arithmetic operations +
, -
, and
*
. These operations all have the same basic form:
Complex operator <op>(Complex a, Complex b) {
    return Complex{ <expression for real>, <expression for imag> };
}
Here, we’ve used uniform initialization syntax to initialize a
Complex
with values for its members. We can then write a
function-like macro to abstract this structure:
#define COMPLEX_OP(op, real_part, imag_part) \
Complex operator op(Complex a, Complex b) {  \
    return Complex{ real_part, imag_part };  \
}
The macro has arguments for each piece that differs between operations, namely the operator, the expression to compute the real part, and the expression to compute the imaginary part. We can use the macro as follows to define the operations:
COMPLEX_OP(+, a.real+b.real, a.imag+b.imag);
COMPLEX_OP(-, a.real-b.real, a.imag-b.imag);
COMPLEX_OP(*, a.real*b.real - a.imag*b.imag,
a.imag*b.real + a.real*b.imag);
As with our initial SWAP
implementation, the trailing semicolon is
extraneous but improves readability and interaction with syntax
highlighters. Running the code through the preprocessor with g++
-E
, we get (modulo spacing):
Complex operator +(Complex a, Complex b) {
    return Complex{ a.real+b.real, a.imag+b.imag };
};
Complex operator -(Complex a, Complex b) {
    return Complex{ a.real-b.real, a.imag-b.imag };
};
Complex operator *(Complex a, Complex b) {
    return Complex{ a.real*b.real - a.imag*b.imag,
                    a.imag*b.real + a.real*b.imag };
};
We can then proceed to define operations between Complex
and
double
values. Again, we observe that such an operation has a
specific pattern:
Complex operator <op>(<type1> a, <type2> b) {
    return <expr1> <op> <expr2>;
}
Here, <exprN>
is the corresponding argument converted to its
Complex
representation. We can abstract this using a macro:
#define REAL_OP(op, typeA, typeB, argA, argB) \
Complex operator op(typeA a, typeB b) {       \
    return argA op argB;                      \
}
We can also define a macro to convert from a double
to a
Complex
:
#define CONVERT(a) \
(Complex{ a, 0 })
We can then define our operations as follows:
REAL_OP(+, Complex, double, a, CONVERT(b));
REAL_OP(+, double, Complex, CONVERT(a), b);
REAL_OP(-, Complex, double, a, CONVERT(b));
REAL_OP(-, double, Complex, CONVERT(a), b);
REAL_OP(*, Complex, double, a, CONVERT(b));
REAL_OP(*, double, Complex, CONVERT(a), b);
Running this through the preprocessor, we get:
Complex operator +(Complex a, double b) { return a + (Complex{ b, 0 }); };
Complex operator +(double a, Complex b) { return (Complex{ a, 0 }) + b; };
Complex operator -(Complex a, double b) { return a - (Complex{ b, 0 }); };
Complex operator -(double a, Complex b) { return (Complex{ a, 0 }) - b; };
Complex operator *(Complex a, double b) { return a * (Complex{ b, 0 }); };
Complex operator *(double a, Complex b) { return (Complex{ a, 0 }) * b; };
We can now use complex numbers as follows:
int main() {
    Complex c1{ 3, 4 };
    Complex c2{ -1, 2 };
    double d = 0.5;
    cout << c1 + c2 << endl;
    cout << c1 - c2 << endl;
    cout << c1 * c2 << endl;
    cout << c1 + d << endl;
    cout << c1 - d << endl;
    cout << c1 * d << endl;
    cout << d + c1 << endl;
    cout << d - c1 << endl;
    cout << d * c1 << endl;
}
This results in:
(2+6i)
(4+2i)
(-11+2i)
(3.5+4i)
(2.5+4i)
(1.5+2i)
(3.5+4i)
(-2.5+-4i)
(1.5+2i)
Stringification and Concatenation
When working with macros, it can be useful to convert a macro argument to a string or to concatenate it with another token. For instance, suppose we wanted to write an interactive application that would read input from a user and perform the corresponding action. On complex numbers, the target functions may be as follows:
Complex Complex_conjugate(Complex c) {
    return Complex{ c.real, -c.imag };
}

string Complex_polar(Complex c) {
    return "(" + to_string(sqrt(pow(c.real, 2) + pow(c.imag, 2))) +
           "," + to_string(atan(c.imag / c.real)) + ")";
}
The application would compare the user input to a string representing an action, call the appropriate function, and print out the result. This has the common pattern:
if (<input> == "<action>")
    cout << Complex_<action>(<value>) << endl;
Here, we both need a string representation of the action, as well as
the ability to concatenate the Complex_
token with the action
token itself. We can define a macro for this pattern as follows:
#define ACTION(str, name, arg) \
    if (str == #name)          \
        cout << Complex_ ## name(arg) << endl
The #
preceding a token is the stringification operator,
converting the token to a string. The ##
between Complex_
and
name
is the token pasting operator, concatenating the tokens on
either side.
We can then write our application code as follows:
Complex c1 { 3, 4 };
string s;
while (cin >> s) {
    ACTION(s, conjugate, c1);
    ACTION(s, polar, c1);
}
Running this through the preprocessor, we obtain the desired result:
Complex c1 { 3, 4 };
string s;
while (cin >> s) {
    if (s == "conjugate") cout << Complex_conjugate(c1) << endl;
    if (s == "polar") cout << Complex_polar(c1) << endl;
}
The Macro Namespace
One pitfall of using CPP macros is that they are not contained within
any particular namespace. In fact, a macro, as long as it is defined,
will replace any eligible token, regardless of where the token is
located. Thus, defining a macro is akin to making a particular
identifier act as a reserved keyword, unable to be used by the
programmer. (This is one reason why constants are usually better
defined as variables qualified const
or constexpr
than as
object-like macros.)
Several conventions are used to avoid polluting the global namespace.
The first is to prefix all macros with characters that are specific to
the library defining them in such a way as to avoid conflict with
other libraries. For instance, our complex-number macros may be
prefixed with COMPLEX_
to avoid conflicting with other macros or
identifiers. The second strategy is to undefine macros when they are
no longer needed, using the #undef
preprocessor directive. For
example, at the end of our library code, we may have the following:
#undef COMPLEX_OP
#undef REAL_OP
#undef CONVERT
#undef ACTION
This frees the identifiers to be used for other purposes in later code.
Code Generation
While macros allow us to generate code using the macro facilities provided by a language, there are some cases where such a facility is unavailable or otherwise insufficient for our purposes. In such a situation, it may be convenient to write a code generator in an external program, in the same language or in a different language. This technique is also called automatic programming.
As an example, the R5RS Scheme specification requires implementations
to provide combinations of car
and cdr
up to four levels deep.
For instance, (caar x)
should be equivalent to (car (car x))
,
and (caddar x)
should be equivalent to (car (cdr (cdr (car
x))))
. Aside from car
and cdr
themselves, there are 28
combinations that need to be provided, which would be tedious and
error-prone to write by hand. Instead, we can define the following
Python script to generate a Scheme library file:
import itertools

def cadrify(seq):
    if len(seq):
        return '(c{0}r {1})'.format(seq[0], cadrify(seq[1:]))
    return 'x'

def defun(seq):
    return '(define (c{0}r x) {1})'.format(''.join(seq), cadrify(seq))

for i in range(2, 5):
    for seq in itertools.product(('a', 'd'), repeat=i):
        print(defun(seq))
The cadrify()
function is a recursive function that takes in a
sequence such as ('a', 'd', 'a')
and constructs a call using the
first item and the recursive result of the rest of the sequence. In
this example, the latter is (cdr (car x))
, so the result would be
(car (cdr (car x)))
. The base case is when the sequence is
empty, producing just x
.
The defun()
function takes in a sequence and uses it to construct the
definition for the appropriate combination. It calls cadrify()
to
construct the body. For the sequence ('a', 'd', 'a')
, the result
is:
(define (cadar x) (car (cdr (car x))))
Finally, the loop at the end produces all combinations of 'a'
and
'd'
for each length. It uses the library function
itertools.product()
to obtain a sequence that is the i
th
power of the tuple ('a', 'd')
. For each combination, it calls
defun()
to generate the function for that combination.
Running the script results in:
(define (caar x) (car (car x)))
(define (cadr x) (car (cdr x)))
(define (cdar x) (cdr (car x)))
(define (cddr x) (cdr (cdr x)))
(define (caaar x) (car (car (car x))))
(define (caadr x) (car (car (cdr x))))
...
(define (cdddar x) (cdr (cdr (cdr (car x)))))
(define (cddddr x) (cdr (cdr (cdr (cdr x)))))
We can place the resulting code in a standard library to be loaded by the Scheme interpreter.
Template Metaprogramming
Template metaprogramming is a technique that uses templates to produce source code at compile time, which is then compiled with the rest of the program’s code. It generally refers to a form of compile-time execution that takes advantage of the language’s rules for template instantiation. Template metaprogramming is most common in C++, though a handful of other languages also enable it.
The key to template metaprogramming in C++ is template
specialization, which allows a specialized definition to be written
for instantiating a template with specific arguments. For example,
consider a class template that contains a static value
field that
is true if the template argument is int
but false otherwise. We
can write the generic template as follows:
template <class T>
struct is_int {
    static const bool value = false;
};
We can now define a specialization for this template when the
argument is int
:
template <>
struct is_int<int> {
    static const bool value = true;
};
The template parameter list in a specialization contains the
non-specialized parameters, if any. In the case above, there are none,
so it is empty. Then after the name of the template, we provide the
full set of arguments for the instantiation, in this case just
int
. We then provide the rest of the definition for the
instantiation.
Now when we use the template, the compiler uses the specialization if the template argument is compatible with the specialization, otherwise it uses the generic template:
cout << is_int<double>::value << endl;
cout << is_int<int>::value << endl;
This prints a 0 followed by a 1.
Template specialization enables us to write code that is conditional on a template argument. Combined with recursive instantiation, this results in template instantiation being Turing complete. Templates do not encode variables that are mutable, so template metaprogramming is actually a form of functional programming.
Pairs
As a more complex example, let us define pairs and lists that can be manipulated at compile time. The elements stored in these structures will be arbitrary types.
Before we proceed to define pairs, we construct a reporting mechanism that allows us to examine results at compile time. We arrange to include the relevant information in an error message generated by the compiler:
template <class A, int I>
struct report {
    static_assert(I < 0, "report");
};
For simplicity, we make use of an integer template parameter, though
we could encode numbers using types instead. When instantiating the
report
template, the static_assert
raises an error if the
template argument I
is nonnegative. Consider the following:
report<int, 5> foo;
The compiler will report an error, indicating what instantiation
caused the static_assert
to fail. In Clang, we get an error like
the following:
pair.cpp:64:3: error: static_assert failed "report"
static_assert(I < 0, "report");
^ ~~~~~
pair.cpp:67:16: note: in instantiation of template class 'report<int, 5>'
requested here
report<int, 5> foo;
^
Using GCC, the error is as follows:
pair.cpp: In instantiation of 'struct report<int, 5>':
pair.cpp:67:16: required from here
pair.cpp:64:3: error: static assertion failed: report
static_assert(I < 0, "report");
^
In both compilers, the relevant information is reported, which is that
the arguments to the report
template are int
and 5.
We can then define a pair template as follows:
template <class First, class Second>
struct pair {
    using car = First;
    using cdr = Second;
};
Within the template, we define type aliases car
and cdr
to
refer to the first and second items of the pair. Thus, pair<int,
double>::car
is an alias for int
, while pair<int,
double>::cdr
is an alias for double
.
We can also define type aliases to extract the first and second items from a pair:
template <class Pair>
using car_t = typename Pair::car;
template <class Pair>
using cdr_t = typename Pair::cdr;
The typename
keyword is required before Pair::car
and
Pair::cdr
, since we are using a nested type whose enclosing type
is dependent on a template parameter. In such a case, C++ cannot
determine that we are naming a type rather than a value, so the
typename
keyword explicitly indicates that it is a type. Using the
aliases above, car_t<pair<int, double>>
is an alias for int
,
while cdr_t<pair<int, double>>
is an alias for double
.
In order to represent recursive lists, we need a representation for the empty list:
struct nil {
};
We can now define a template to determine whether or not a list,
represented either by the empty list nil
or by a
nil
-terminated sequence of pair
s, is empty. We define a
generic template and then a specialization for the case of nil
as
the argument:
template <class List>
struct is_empty {
    static const bool value = false;
};

template <>
struct is_empty<nil> {
    static const bool value = true;
};
In order to use the field value
at compile time, it must be a
compile-time constant, which we can arrange by making it both static
and const
and initializing it with a compile-time constant. With
C++14, we can also define a variable template to encode whether a
list is empty:
template <class List>
const bool is_empty_v = is_empty<List>::value;
The value of is_empty_v<nil>
is true, while is_empty_v<pair<int,
nil>>
is false. Then we can determine at compile time whether or not
a list is empty:
using x = pair<char, pair<int, pair<double, nil>>>;
using y = pair<float, pair<bool, nil>>;
using z = nil;
report<x, is_empty_v<x>> a;
report<y, is_empty_v<y>> b;
report<z, is_empty_v<z>> c;
Here, we introduce type aliases for lists, which act as immutable
compile-time variables. We then instantiate report
with a type and
whether or not it is empty. This results in the following error
messages from GCC:
pair.cpp: In instantiation of 'struct report<pair<char, pair<int,
pair<double, nil> > >, 0>':
pair.cpp:82:28: required from here
pair.cpp:73:3: error: static assertion failed: report
static_assert(I < 0, "report");
^~~~~~~~~~~~~
pair.cpp: In instantiation of 'struct report<pair<float, pair<bool,
nil> >, 0>':
pair.cpp:83:28: required from here
pair.cpp:73:3: error: static assertion failed: report
pair.cpp: In instantiation of 'struct report<nil, 1>':
pair.cpp:84:28: required from here
pair.cpp:73:3: error: static assertion failed: report
Examining the integer argument of report
, we see that the lists
pair<char, pair<int, pair<double, nil>>>
and pair<float,
pair<bool, nil>>
are not empty, but the list nil
is.
We can compute the length of a list using recursion:
template <class List>
struct length {
    static const int value = length<cdr_t<List>>::value + 1;
};

template <>
struct length<nil> {
    static const int value = 0;
};

template <class List>
const int length_v = length<List>::value;
Here, we are using a value from a recursive instantiation of the
length
struct. Since value
is initialized with an expression
consisting of an operation between compile-time constants, it is also
a compile-time constant. The recursion terminates at the
specialization for length<nil>
, where the value
member is
directly initialized to 0. As with is_empty_v
, we define a
variable template length_v
to encode the result. We can compute
and report the length of the x
type alias:
report<x, length_v<x>> d;
The first argument to report
is arbitrary, since we only care
about the second argument, so we just pass x
itself. We get:
pair.cpp: In instantiation of 'struct report<pair<char, pair<int,
pair<double, nil> > >, 3>':
pair.cpp:85:26: required from here
pair.cpp:73:3: error: static assertion failed: report
The relevant information is that the length is 3.
We can define even more complex manipulation on lists. For instance, we can reverse a list as follows:
template <class List, class SoFar>
struct reverse_helper {
    using type =
        typename reverse_helper<cdr_t<List>,
                                pair<car_t<List>, SoFar>>::type;
};

template <class SoFar>
struct reverse_helper<nil, SoFar> {
    using type = SoFar;
};

template <class List>
using reverse_t = typename reverse_helper<List, nil>::type;
Here, we use a helper template to perform the reversal, where the
first template argument is the remaining list and the second is the
reversed list so far. In each step, we compute a new partial result as
pair<car_t<List>, SoFar>
, adding the first item in the remaining
list to the front of the previous partial result. Then cdr_t<List>
is the remaining list excluding the first item.
The base case of the recursion is when the remaining list is nil
,
in which case the final result is the same as the partial result. We
accomplish this with a partial class template specialization, which
allows us to specialize only some of the arguments to a class template
[1]. In reverse_helper
, we specialize the first argument, so that
any instantiation of reverse_helper
where the first argument is
nil
will use the specialization. The specialization retains a
template parameter, which is included in its parameter list. The full
argument list appears after the template name, including both the
specialized and unspecialized arguments.
We seed the whole computation in the reverse_t
alias template with
the original list and empty partial result. We apply reverse_t
to
x
:
report<reverse_t<x>, 0> e;
Here, the second argument is an arbitrary nonnegative value. We get:
pair.cpp: In instantiation of 'struct report<pair<double, pair<int,
pair<char, nil> > >, 0>':
pair.cpp:86:27: required from here
pair.cpp:73:3: error: static assertion failed: report
As a last example, we can now write a template to append two lists:
template <class List1, class List2>
struct append {
    using type =
        pair<car_t<List1>,
             typename append<cdr_t<List1>, List2>::type>;
};

template <class List2>
struct append<nil, List2> {
    using type = List2;
};

template <class List1, class List2>
using append_t = typename append<List1, List2>::type;
Here, the template appends the second argument to the first argument.
This is accomplished by prepending the first item of the first list to
the result of appending the second list to the rest of the first list.
The recursion terminates when the first list is empty. Applying
append_t
to x
and y
:
report<append_t<x, y>, 0> f;
We get:
pair.cpp: In instantiation of 'struct report<pair<char, pair<int,
pair<double, pair<float, pair<bool, nil> > > > >, 0>':
pair.cpp:87:29: required from here
pair.cpp:73:3: error: static assertion failed: report
Numerical Computations
Using just recursion and template specialization, we could encode numbers using a system like Church numerals. However, C++ also supports integral template parameters, so we can perform compile-time numerical computations using an integer parameter rather than just types.
As an example, consider the following definition of a template to compute the factorial of the template parameter:
template <int N>
struct factorial {
    static const long long value = N * factorial<N - 1>::value;
};

template <>
struct factorial<0> {
    static const long long value = 1;
};
The generic template multiplies its template argument N
by the
result of computing factorial on N - 1
. The base case is provided
by the specialization for when the argument is 0, where the factorial
is 1.
Here, we’ve used a long long
to hold the computed value, so that
larger results can be computed than can be represented by int
. We
define a template to report a result as follows:
template <long long N>
struct report {
    static_assert(N > 0 && N < 0, "report");
};
The condition of the static_assert
is written to depend on the
template parameter so that the assertion fails during instantiation,
rather than before. Then if we compute the factorial of 5:
report<factorial<5>::value> a;
We get:
factorial.cpp: In instantiation of 'struct report<120ll>':
factorial.cpp:37:34: required from here
factorial.cpp:33:3: error: static assertion failed: report
static_assert(N > 0 && N < 0, "report");
^
This shows that the result is 120.
We can use a macro to make our program more generic, encoding the
argument to factorial
as a macro that can be defined at compile
time:
report<factorial<NUM>::value> a;
We can even provide a default value:
#ifndef NUM
#define NUM 5
#endif
Then at the command line, we can specify the argument as follows:
$ g++ --std=c++14 factorial.cpp -DNUM=20
factorial.cpp: In instantiation of 'struct report<2432902008176640000ll>':
factorial.cpp:27:33: required from here
factorial.cpp:23:3: error: static assertion failed: report
static_assert(N > 0 && N < 0, "report");
^
The command-line argument -D
in GCC and Clang allows us to define
a macro from the command line.
Suppose we now attempt to compute the factorial of a negative number:
$ g++ --std=c++14 factorial.cpp -DNUM=-1
factorial.cpp: In instantiation of 'const long long int
factorial<-900>::value':
factorial.cpp:23:36: recursively required from 'const long long int
factorial<-2>::value'
factorial.cpp:23:36: required from 'const long long int
factorial<-1>::value'
factorial.cpp:37:27: required from here
factorial.cpp:23:36: fatal error: template instantiation depth exceeds
maximum of 900 (use -ftemplate-depth= to increase the maximum)
static const long long value = N * factorial<N - 1>::value;
^
compilation terminated.
We see that the recursion never reaches the base case of 0. Instead, the compiler terminates compilation when the recursion depth reaches its limit. We can attempt to add an assertion that the template argument is non-negative as follows:
template <int N>
struct factorial {
static_assert(N >= 0, "argument to factorial must be non-negative");
static const long long value = N * factorial<N - 1>::value;
};
However, this does not prevent the recursive instantiation, so that what we get is an even longer set of error messages:
factorial.cpp: In instantiation of 'struct factorial<-1>':
factorial.cpp:38:25: required from here
factorial.cpp:23:3: error: static assertion failed: argument to factorial
must be non-negative
static_assert(N >= 0, "argument to factorial must be non-negative");
^
...
factorial.cpp: In instantiation of 'struct factorial<-900>':
factorial.cpp:24:36: recursively required from 'const long long int
factorial<-2>::value'
factorial.cpp:24:36: required from 'const long long int
factorial<-1>::value'
factorial.cpp:38:27: required from here
factorial.cpp:23:3: error: static assertion failed: argument to factorial
must be non-negative
factorial.cpp: In instantiation of 'const long long int
factorial<-900>::value':
factorial.cpp:24:36: recursively required from 'const long long int
factorial<-2>::value'
factorial.cpp:24:36: required from 'const long long int
factorial<-1>::value'
factorial.cpp:38:27: required from here
factorial.cpp:24:36: fatal error: template instantiation depth exceeds
maximum of 900 (use -ftemplate-depth= to increase the maximum)
static const long long value = N * factorial<N - 1>::value;
^
compilation terminated.
Here, we have removed the intermediate error messages between -1 and -900.
In order to actually prevent recursive instantiation when the argument is negative, we can offload the recursive work to a helper template. We can then check that the argument is non-negative in factorial, converting the argument to 0 if it is negative:
template <int N>
struct factorial_helper {
static const long long value = N * factorial_helper<N - 1>::value;
};
template <>
struct factorial_helper<0> {
static const long long value = 1;
};
template <int N>
struct factorial {
static_assert(N >= 0, "argument to factorial must be non-negative");
static const long long value = factorial_helper<N >= 0 ? N : 0>::value;
};
The key here is that factorial only ever instantiates factorial_helper with a nonnegative argument: when N is negative, the conditional expression passes 0 to factorial_helper instead of N, so the runaway recursion never starts. Thus, we get:
$ g++ --std=c++14 factorial.cpp -DNUM=-1
factorial.cpp: In instantiation of 'struct factorial<-1>':
factorial.cpp:38:24: required from here
factorial.cpp:17:3: error: static assertion failed: argument to factorial
must be non-negative
static_assert(N >= 0, "argument to factorial must be non-negative");
^
factorial.cpp: In instantiation of 'struct report<1ll>':
factorial.cpp:38:33: required from here
factorial.cpp:34:3: error: static assertion failed: report
static_assert(N > 0 && N < 0, "report");
^
We no longer have an unbounded recursion. This demonstrates how we can achieve conditional compilation, even without a built-in conditional construct.
An alternative strategy is to use a second, defaulted template argument that tracks whether or not the first argument is positive:
template <int N, bool /*Positive*/ = (N > 0)>
struct factorial {
static const long long value = N * factorial<N - 1>::value;
};
template <int N>
struct factorial<N, false> {
static const long long value = 1;
};
When we instantiate factorial with a positive argument, as in factorial<5>, the second argument defaults to true. Since that does not match the partial specialization, the instantiation uses the generic version of the template. On the other hand, if we instantiate the template with a non-positive argument, such as in factorial<0>, the second argument defaults to false, resulting in the partial specialization being used. Thus, the defaulted argument serves to control whether the generic or specialized version is used. Since it’s not used for anything else, we need not name the argument, but we have included the /*Positive*/ comment to document its purpose.
As another example of a numerical computation, the following computes Fibonacci numbers at compile time. For simplicity, we do not implement error checking for negative arguments:
template <int N>
struct fib {
static const long long value = fib<N - 1>::value + fib<N - 2>::value;
};
template <>
struct fib<1> {
static const long long value = 1;
};
template <>
struct fib<0> {
static const long long value = 0;
};
We have two base cases, provided by separate specializations for when the argument is 0 or 1. As with factorial, we use a macro to represent the input:
report<fib<NUM>::value> a;
We can then specify the input at the command line:
$ g++ --std=c++14 fib.cpp -DNUM=7
fib.cpp: In instantiation of 'struct report<13ll>':
fib.cpp:26:27: required from here
fib.cpp:22:3: error: static assertion failed: report
static_assert(N > 0 && N < 0, "report");
^
We can even provide the largest input for which the Fibonacci number is representable as a long long:
$ g++ --std=c++14 fib.cpp -DNUM=92
fib.cpp: In instantiation of 'struct report<7540113804746346429ll>':
fib.cpp:26:27: required from here
fib.cpp:22:3: error: static assertion failed: report
static_assert(N > 0 && N < 0, "report");
^
This computation only takes a fraction of a second, since the C++ compiler only instantiates a template once for a given set of arguments within a single translation unit. Thus, the compiler automatically performs memoization, saving the result of a single computation rather than repeating it.
Templates and Function Overloading
While function templates can be specialized, a function template can also be overloaded with a non-template function. In performing overload resolution, C++ prefers a non-template function over a template instantiation when both match the arguments equally well.
As an example, consider the following function template to convert a value to a string representation:
template <class T>
string to_string(const T &item) {
std::ostringstream oss;
oss << item;
return oss.str();
}
We can make use of this template, with the compiler performing template-argument deduction, as follows:
int main() {
cout << to_string(Complex{ 3, 3.14 }) << endl;
cout << to_string(3.14) << endl;
cout << to_string(true) << endl;
}
This results in:
(3+3.14i)
3.14
1
If we then decide that this representation of a bool is undesirable, we can write a function overload as follows:
string to_string(bool item) {
return item ? "true" : "false";
}
Since this is a non-template function, C++ prefers it to the template instantiation to_string<bool> when the argument type is bool. Thus, the same code in main() now results in:
(3+3.14i)
3.14
true
SFINAE
In considering function overloads, the C++ compiler does not consider it an error if the types and expressions used in the header of a function template are unsuitable for a particular set of template arguments. This is known as substitution failure is not an error (SFINAE), and it is a powerful feature of templates in C++. Rather than producing an error in such a case, the compiler simply removes the template from the set of candidate functions to be considered in overload resolution.
As an example, suppose we want to modify our to_string() to use std::to_string() for the types for which the latter is defined. We can place a dependence on the existence of a suitable std::to_string() overload in the header of a new function template:
template <class T>
auto to_string(const T &item) -> decltype(std::to_string(item)) {
return std::to_string(item);
}
Here, the trailing return type is necessary so that std::to_string(item) appears in the header of the function. The function template then fails on substitution if there is no overload of std::to_string() that can be applied to a value of the template argument’s type. For example, consider calling our to_string() on a Complex object:
cout << to_string(Complex{ 3, 3.14 }) << endl;
Our previous to_string() template is still viable, so it is considered in overload resolution. The new template we defined above, however, fails to substitute, since there is no definition of std::to_string() that can be applied to a Complex. Rather than being an error, the second template is merely removed from consideration, and the call resolves to the original template.
With the second template definition, we can still call to_string() on a bool, since C++ still prefers the non-template function. However, we run into trouble when attempting to call it on a double:
to_string.cpp:82:11: error: call to 'to_string' is ambiguous
cout << to_string(3.14) << endl;
^~~~~~~~~~
to_string.cpp:65:8: note: candidate function [with T = double]
string to_string(const T &item) {
^
to_string.cpp:72:6: note: candidate function [with T = double]
auto to_string(const T &item) -> decltype(std::to_string(item)) {
^
to_string.cpp:76:8: note: candidate function
string to_string(bool item) {
^
1 error generated.
Both templates are equally viable when the argument is of type double, so the compiler cannot disambiguate between them. The non-template overload that takes in a bool is also viable, since a double can be converted to a bool, so it is reported in the error message even though it is inferior to either template.
In order to fix this problem, we need to arrange for the first function template to be nonviable when there is a compatible overload of std::to_string(). This requires ensuring that the template encounters a substitution failure in that case.
Ensuring a Substitution Failure
There are many tools for ensuring a substitution failure. Perhaps the most fundamental is the enable_if template, defined in the standard library’s <type_traits> header as of C++11. We can also define it ourselves as follows:
template <bool B, class T>
struct enable_if {
typedef T type;
};
template <class T>
struct enable_if<false, T> {
};
The generic template takes in a bool and a type and defines a member alias for the type argument. The specialization elides this alias when the bool argument is false. C++14 additionally defines enable_if_t as an alias template, as in the following:
template <bool B, class T>
using enable_if_t = typename enable_if<B, T>::type;
We can use enable_if or enable_if_t to induce a failure, as in the following definition of factorial:
template <int N>
struct factorial {
static const std::enable_if_t<N >= 0, long long> value =
N * factorial<N - 1>::value;
};
When the template argument N is negative, the enable_if instantiation has no type member, so we get an error:
In file included from factorial.cpp:1:0:
/opt/local/include/gcc5/c++/type_traits: In substitution of
'template<bool _Cond, class _Tp> using enable_if_t = typename
std::enable_if::type [with bool _Cond = false; _Tp = long long
int]':
factorial.cpp:36:52: required from 'struct factorial<-1>'
factorial.cpp:51:25: required from here
/opt/local/include/gcc5/c++/type_traits:2388:61: error: no type
named 'type' in 'struct std::enable_if<false, long long int>'
using enable_if_t = typename enable_if<_Cond, _Tp>::type;
^
factorial.cpp: In function 'int main()':
factorial.cpp:51:10: error: 'value' is not a member of 'factorial<-1>'
report<factorial<NUM>::value> a;
^
factorial.cpp:51:10: error: 'value' is not a member of 'factorial<-1>'
factorial.cpp:51:32: error: template argument 1 is invalid
report<factorial<NUM>::value> a;
^
This provides us another mechanism to prevent instantiation of a template with a semantically invalid argument. In this case, substitution failure is an error, since the failure did not occur in the header of a function template.
Another option we have is to rely on the fact that variadic arguments are the least preferred alternative in function-overload resolution. Thus, we can write our overloads as helper functions or function templates, with an additional argument to be considered in overload resolution:
string to_string_helper(bool item, int ignored) {
return item ? "true" : "false";
}
template <class T>
auto to_string_helper(const T &item, int ignored)
-> decltype(std::to_string(item)) {
return std::to_string(item);
}
template <class T>
string to_string_helper(const T &item, ...) {
std::ostringstream oss;
oss << item;
return oss.str();
}
template <class T>
string to_string(const T &item) {
return to_string_helper(item, 0);
}
Here, to_string() calls to_string_helper() with the item and a dummy integer argument. We define three overloads of to_string_helper() as before, except that the overloads for bool and for types supported by std::to_string() take in an extra int argument, while the generic overload that is viable for all types uses variadic arguments. Since variadic arguments have the lowest priority in function-overload resolution, if both the generic overload and another overload are viable, the latter is chosen. Thus, the overload that uses std::to_string() is preferred when to_string_helper() is called on a double. We no longer have an ambiguity, and we get the desired result when the program is compiled and run:
(3+3.14i)
3.140000
true
Variadic Templates
As of the C++11 standard, C++ supports variadic templates, which are templates that take a variable number of arguments. Both class and function templates can be variadic, and variadic templates enable us to write variadic function overloads that are type safe, unlike C-style varargs.
As an example, consider the definition of a tuple template that encapsulates multiple items of arbitrary type. We can declare such a template as follows:
template <class... Types>
struct tuple;
The template parameter Types is a parameter pack, which accepts zero or more arguments. In this case, the ellipsis follows the class keyword, so the arguments accepted by the parameter pack are types. We can then declare tuples as follows:
tuple<> t0;
tuple<int> t1;
tuple<double, char, int> t2;
In the first instantiation, the parameter pack Types is empty, since no template arguments were provided. In the second instantiation, Types is associated with the single argument int, and in the last case, Types is associated with the three arguments double, char, and int.
Within the template definition, we can use the sizeof... operator to determine the size of the parameter pack. Thus, we can compute the size of the tuple as:
static const int size = sizeof...(Types);
Parameter packs are often processed recursively. It is natural to define a tuple itself recursively as a combination of the first data item and a smaller tuple containing all but the first. The following is a specialization for a non-empty tuple (i.e. a tuple with at least one element):
template <class First, class... Rest>
struct tuple<First, Rest...> {
static const int size = 1 + sizeof...(Rest);
using first_type = First;
using rest_type = tuple<Rest...>;
first_type first;
rest_type rest;
// ...
};
The ellipsis, when it appears to the right of a pattern containing a parameter pack, expands the pattern into comma-separated instantiations of the pattern, one per item in the parameter pack. Thus, if First is associated with double and Rest is associated with char and int, tuple<First, Rest...> expands to tuple<double, char, int>.
In the code above, we have introduced type aliases for the type of the first data item and the type of the rest of the tuple. We then declared data members for each of these components. We can write a constructor to initialize them as follows:
tuple(First f, Rest... r) : first(f), rest(r...) {}
With First as double and Rest as above, this expands to the equivalent of:
tuple(double f, char r0, int r1) :
first(f), rest(r0, r1) {}
Both the parameter Rest... r and the use of the parameter r... expand, with r replaced by a unique identifier in each instantiation of the pattern.
The full definition of the template specialization is as follows:
template <class First, class... Rest>
struct tuple<First, Rest...> {
static const int size = 1 + sizeof...(Rest);
using first_type = First;
using rest_type = tuple<Rest...>;
first_type first;
rest_type rest;
tuple(First f, Rest... r) : first(f), rest(r...) {}
};
Since this is a recursive definition, we need a base case to terminate the recursion. It is natural to choose an empty tuple as the base case. We can define this with another specialization:
template <>
struct tuple<> {
static const int size = 0;
};
To facilitate using a tuple, we can write a function template to construct one. This takes advantage of argument deduction for function templates, which is not available for class templates prior to C++17. We write a make_tuple variadic function template as follows:
template <class... Types>
tuple<Types...> make_tuple(Types... items) {
return tuple<Types...>(items...);
}
We can now make use of this function template to construct a tuple:
tuple<int> t1 = make_tuple(3);
tuple<double, char, int> t2 = make_tuple(4.9, 'c', 3);
While we now have the ability to construct a tuple, we have not yet provided a convenient mechanism for accessing individual elements from a tuple. To do so, we can write a function template as follows:
template<int Index, class Tuple>
auto &get(Tuple &tup) {
static_assert(Index >= 0 and Index < Tuple::size, "bad index");
if constexpr (Index == 0) {
return tup.first;
} else {
return get<Index - 1>(tup.rest);
}
}
Here, we make use of C++14’s return type deduction to avoid computing the type of a tuple element ourselves. In addition, we use a C++17 constexpr if statement to handle both the recursive and base case within a single function template. Such a conditional requires a compile-time constant as the test, and the conditional is resolved at compile time.
To use get(), we must explicitly provide a value for the Index template parameter, since it cannot be deduced from the function arguments. We can rely on argument deduction for the second Tuple template parameter:
tuple<double, char, int> t2 = make_tuple(4.9, 'c', 3);
cout << get<0>(t2) << endl;
cout << get<1>(t2) << endl;
cout << get<2>(t2) << endl;
++get<0>(t2);
++get<1>(t2);
++get<2>(t2);
cout << get<0>(t2) << endl;
cout << get<1>(t2) << endl;
cout << get<2>(t2) << endl;
This results in:
4.9
c
3
5.9
d
4
The standard library provides a definition of tuple, along with make_tuple() and get(), in the <tuple> header.
Alternate Pre-C++14 Implementation
Without a constexpr if statement, we would need to write separate function templates for the recursive case of get() with Index > 0 and for the base case of Index == 0. However, since C++ function templates do not allow partial specialization, we would need to wrap these cases within a class template, which we can then partially specialize. We proceed to do so to demonstrate this technique.
To start off, we first write a class template to contain a reference to a single element from a tuple. We declare it as follows:
template <int Index, class Tuple>
struct tuple_element;
The parameter Index is the index of the item referenced by a tuple_element, and Tuple is the type of the tuple itself. We can then write the base case as follows:
template <class Tuple>
struct tuple_element<0, Tuple> {
using type = typename Tuple::first_type;
type &item;
tuple_element(Tuple &t) : item(t.first) {}
};
The type of the element at index 0 is aliased by the first_type member of a tuple, and the element itself is the first data member of a tuple object. Thus, we initialize our reference to the item with the first member of the tuple argument to the constructor. We also introduce a type alias type to refer to the type of the item.
The recursive case decrements the index and passes off the computation to a tuple_element instantiated with all but the first item in the tuple:
template <int Index, class Tuple>
struct tuple_element {
using rest_type = tuple_element<Index - 1,
typename Tuple::rest_type>;
using type = typename rest_type::type;
type &item;
tuple_element(Tuple &t) : item(rest_type(t.rest).item) {}
};
The rest_type member alias of a tuple is the type representing all but the first item in the tuple. Within tuple_element, we alias rest_type to recursively refer to a tuple_element with a decremented index and the rest_type of the tuple. We then arrange to retrieve the item from this recursive instantiation: the constructor creates a smaller tuple_element and initializes item to refer to the item contained within it. We similarly alias type to refer to the type contained in the smaller tuple_element.
The following is an alias template for the type of a tuple element:
template <int Index, class Tuple>
using tuple_element_t = typename tuple_element<Index, Tuple>::type;
We can now write a function template to retrieve an item out of a tuple:
template <int Index, class... Types>
tuple_element_t<Index, tuple<Types...>> &get(tuple<Types...> &t) {
return tuple_element<Index, tuple<Types...>>(t).item;
}
The work is offloaded to the tuple_element class template, from which we retrieve both the type of the element and the element itself.
Example: Multidimensional Arrays
As an extended example of using metaprogramming to build a complex system, let’s consider the implementation of a multidimensional array library in C++. Built-in C++ arrays are very limited: they represent only a linear sequence of elements, and they do not carry any size information. Multidimensional arrays can be represented by arrays of arrays, but this representation can be cumbersome to use and can suffer from poor spatial locality. Instead, most applications linearize a multidimensional array and map a multidimensional index to a linear index. We will use this strategy, but we will abstract the translation logic behind an ADT interface.
Points
We start with an abstraction for a multidimensional index, which we call a point. A point consists of a sequence of integer indices, such as \((3, 4, 5)\) for a three-dimensional index. We define a point template as follows:
template <int N>
struct point {
int coords[N];
int &operator[](int i) {
return coords[i];
}
const int &operator[](int i) const {
return coords[i];
}
};
The template is parameterized by the dimensionality of the point, and its data representation is an array of coordinates. We overload the index operator for both const and non-const points.
We provide a stream-insertion operator overload as follows:
template <int N>
std::ostream &operator<<(std::ostream &os, const point<N> &p) {
os << "(" << p[0];
for (int i = 1; i < N; i++) {
os << "," << p[i];
}
return os << ")";
}
In order to work with points, it is useful to have point-wise arithmetic operations on points, as well as comparison operators. For instance, the following are possible definitions of addition and equality:
template <int N>
point<N> operator+(const point<N> &a, const point<N> &b) {
point<N> result;
for (int i = 0; i < N; i++)
result[i] = a[i] + b[i];
return result;
}
template <int N>
bool operator==(const point<N> &a, const point<N> &b) {
bool result = true;
for (int i = 0; i < N; i++)
result = result && (a[i] == b[i]);
return result;
}
There is a lot of similarity between these two functions: they share the same template header, arguments, and overall body structure, with an initial value, a loop to update the value, and a return of that value. Rather than writing several arithmetic and comparison operations with this structure, we can use a function-like macro to abstract the common structure:
#define POINT_OP(op, rettype, header, action, retval) \
template <int N> \
rettype operator op(const point<N> &a, const point<N> &b) { \
header; \
for (int i = 0; i < N; i++) \
action; \
return retval; \
}
Then arithmetic operators such as + and - can be defined as follows:
POINT_OP(+, point<N>, point<N> result,
result[i] = a[i] + b[i], result);
POINT_OP(-, point<N>, point<N> result,
result[i] = a[i] - b[i], result);
These in turn are very similar, the only difference being the two occurrences of + or -. We can abstract this structure further for arithmetic operations:
#define POINT_ARITH_OP(op) \
POINT_OP(op, point<N>, point<N> result, \
result[i] = a[i] op b[i], result)
Similarly, we can abstract the structure for comparison operations:
#define POINT_COMP_OP(op, start, combiner) \
POINT_OP(op, bool, bool result = start, \
result = result combiner (a[i] op b[i]), result)
We can now use these macros to define the point operations:
POINT_ARITH_OP(+);
POINT_ARITH_OP(-);
POINT_ARITH_OP(*);
POINT_ARITH_OP(/);
POINT_COMP_OP(==, true, &&);
POINT_COMP_OP(!=, false, ||);
POINT_COMP_OP(<, true, &&);
POINT_COMP_OP(<=, true, &&);
POINT_COMP_OP(>, true, &&);
POINT_COMP_OP(>=, true, &&);
Compared to writing ten separate functions, this strategy has much less repetition.
One last operation that would be useful is constructing a point of the desired dimensionality from a sequence of coordinates, analogous to make_tuple() from the previous section. We can define a variadic function to do so as follows, giving it the name pt() for succinctness:
template <class... Is>
point<sizeof...(Is)> pt(Is... is) {
return point<sizeof...(Is)>{{ is... }};
}
We use the sizeof... operator to compute the dimensionality. The nested initializer lists are required: the outer one for the point struct itself and the inner one for its coords member, since the latter is an array.
We can now perform operations on points:
cout << (pt(3, 4) + pt(1, -2)) << endl;
cout << (pt(1, 2, 3) < pt(3, 4, 5)) << endl;
This results in:
(4,2)
1
Domains
The domain of an array is the set of points that it maps to elements. A domain is rectangular if the start and end index for each dimension is independent of the indices for the other dimensions. Thus, an array over a rectangular domain maps a rectangular region of space to elements.
We can represent a rectangular domain by an inclusive lower-bound point and an exclusive upper-bound point:
template <int N>
struct rectdomain {
point<N> lwb; // inclusive lower bound
point<N> upb; // exclusive upper bound
// Returns the number of points in this domain.
int size() const {
if (!(lwb < upb))
return 0;
int result = 1;
for (int i = 0; i < N; i++) {
// multiply by the span of each dimension
result *= upb[i] - lwb[i];
}
return result;
}
};
We can define an iterator over a rectangular domain as follows, writing it as a nested class within the rectdomain template:
template <int N>
struct rectdomain {
...
struct iterator {
point<N> lwb; // inclusive lower bound
point<N> upb; // exclusive upper bound
point<N> current; // current item
// Returns the current point.
point<N> operator*() const {
return current;
}
// Moves this iterator to the next point in the domain.
iterator &operator++() {
// Increment starting at the last dimension.
for (int i = N - 1; i >= 0; i--) {
current[i]++;
// If this dimension is within bounds, then we are done.
if (current[i] < upb[i])
return *this;
// Otherwise, reset this dimension to its minimum and move
// on to the previous one.
current[i] = lwb[i];
}
// We ran out of dimensions to increment, set this to an end
// iterator.
current = upb;
return *this;
}
bool operator==(const iterator &rhs) const {
return current == rhs.current;
}
bool operator!=(const iterator &rhs) const {
return !operator==(rhs);
}
};
// Return an iterator that is set to the inclusive lower-bound
// point.
iterator begin() const {
return iterator{ lwb, upb, lwb };
}
// Return an iterator that is set to the exclusive upper-bound
// point.
iterator end() const {
return iterator{ lwb, upb, upb };
}
};
The iterator keeps track of the lower and upper bounds, as well as the current point. Incrementing an iterator increments the last coordinate of the current point, and if that reaches the upper bound for that coordinate, it is set to the lower bound and the previous coordinate is incremented instead. This process is repeated as necessary, and if the first coordinate reaches its upper bound, the iterator reaches the end.
We can now use rectangular domains as follows:
for (auto p : rectdomain<3>{ pt(1, 2, 3), pt(3, 4, 5) })
cout << p << endl;
This results in:
(1,2,3)
(1,2,4)
(1,3,3)
(1,3,4)
(2,2,3)
(2,2,4)
(2,3,3)
(2,3,4)
Arrays
We can now proceed to define an ADT for a multidimensional array. We can represent it with a rectangular domain and a C++ array to store the elements. We also keep track of the size of each dimension for the purposes of index computations. The following is an implementation:
template <class T, int N>
struct ndarray {
rectdomain<N> domain; // domain of this array
int sizes[N]; // cached size of each dimension
T *data; // storage for the elements
// Constructs an array with the given domain, default initializing
// the elements.
ndarray(const rectdomain<N> &dom)
: domain(dom), data(new T[dom.size()]) {
// Compute and store sizes of each dimension.
for (int i = 0; i < N; i++) {
sizes[i] = domain.upb[i] - domain.lwb[i];
}
}
// Copy constructor does a deep copy.
ndarray(const ndarray &rhs)
: domain(rhs.domain), data(new T[domain.size()]) {
std::copy(rhs.data, rhs.data + domain.size(), data);
std::copy(rhs.sizes, rhs.sizes + N, sizes);
}
// Assignment operator does a deep copy.
ndarray &operator=(const ndarray &rhs) {
if (&rhs == this)
return *this;
delete[] data;
domain = rhs.domain;
data = new T[domain.size()];
std::copy(rhs.data, rhs.data + domain.size(), data);
std::copy(rhs.sizes, rhs.sizes + N, sizes);
return *this;
}
// Destructor deletes the underlying storage and the elements
// within.
~ndarray() {
delete[] data;
}
// Translates a multidimensional point index into a
// single-dimensional index into the storage array.
int indexof(const point<N> &index) const;
// Returns the element at the given multidimensional index.
T &operator[](const point<N> &index) {
return data[indexof(index)];
}
// Returns the element at the given multidimensional index.
const T &operator[](const point<N> &index) const {
return data[indexof(index)];
}
};
The class template is parameterized by the element type and dimensionality. A constructor takes in a rectangular domain, allocates an underlying array of the appropriate size to hold the elements, and stores the size of each dimension. The Big Three are implemented as needed. (We elide the move constructor and move assignment operator for simplicity.) We then have a function to translate a multidimensional index into a linear one, which the overloaded index operators use to obtain an element.
The indexof() function combines the input point with the size of each dimension to linearize the index. In our representation, the array is stored in row-major order, so that the last dimension is the contiguous one:
template <class T, int N>
int ndarray<T, N>::indexof(const point<N> &index) const {
int result = index[0] - domain.lwb[0];
for (int i = 1; i < N; i++) {
result = result * sizes[i] + (index[i] - domain.lwb[i]);
}
return result;
}
Since the value of N is a compile-time constant, this loop can be trivially unrolled by the compiler, eliminating branching and resulting in a faster computation.
Stencil
We can now use arrays to perform a stencil computation, which iteratively computes the value of a grid point based on its previous value and the previous values of its neighbors. Figure 48 is an example of a stencil update associated with Conway’s Game of Life, on a \(3 \times 3\) grid.
We use two grids, one for the previous timestep and one for the current one. We use ghost regions at the edges of the grids, extending each edge by an extra point, to avoid having to do separate computations at the boundaries.
The following constructs three-dimensional grids of size \(xdim \times ydim \times zdim\), with ghost regions:
rectdomain<3> domain{ pt(-1, -1, -1), pt(xdim+1, ydim+1, zdim+1) };
rectdomain<3> interior{ pt(0, 0, 0), pt(xdim, ydim, zdim) };
ndarray<double, 3> gridA(domain);
ndarray<double, 3> gridB(domain);
We initialize the grids as needed and then perform an iterative stencil computation as follows:
void probe(ndarray<double, 3> *gridA_ptr,
           ndarray<double, 3> *gridB_ptr,
           const rectdomain<3> &interior, int steps) {
  for (int i = 0; i < steps; i++) {
    ndarray<double, 3> &gridA = *gridA_ptr;
    ndarray<double, 3> &gridB = *gridB_ptr;
    for (auto p : interior) {
      gridB[p] =
        gridA[p + pt( 0,  0,  1)] +
        gridA[p + pt( 0,  0, -1)] +
        gridA[p + pt( 0,  1,  0)] +
        gridA[p + pt( 0, -1,  0)] +
        gridA[p + pt( 1,  0,  0)] +
        gridA[p + pt(-1,  0,  0)] +
        WEIGHT * gridA[p];
    }
    // Swap pointers
    std::swap(gridA_ptr, gridB_ptr);
  }
}
This code makes use of iteration over a rectangular domain, arithmetic on points, and the use of points to index into the multidimensional array. At the end of each timestep, we swap which grid is the current one and which is the previous.
While this code is simple to write, it does not perform well on many compilers. The linearized iteration over the rectangular domain can prevent a compiler from optimizing the iteration order to make the best use of the memory hierarchy, such as with a polyhedral analysis. In GCC, for example, we find that a nested loop structure such as the following can be five times more efficient:
for (p[0] = interior.lwb[0]; p[0] < interior.upb[0]; p[0]++) {
  for (p[1] = interior.lwb[1]; p[1] < interior.upb[1]; p[1]++) {
    for (p[2] = interior.lwb[2]; p[2] < interior.upb[2]; p[2]++) {
      gridB[p] =
        gridA[p + pt( 0,  0,  1)] +
        gridA[p + pt( 0,  0, -1)] +
        gridA[p + pt( 0,  1,  0)] +
        gridA[p + pt( 0, -1,  0)] +
        gridA[p + pt( 1,  0,  0)] +
        gridA[p + pt(-1,  0,  0)] +
        WEIGHT * gridA[p];
    }
  }
}
This code is less simple, and it introduces a further dependency on the dimensionality of the grid, preventing us from generalizing it to an arbitrary number of dimensions.
Nested Iteration
In order to solve the problem of linearized iteration, we can use metaprogramming to turn what appears to be a single loop into a nested one, making it more amenable to analysis and optimization. We start by writing a recursive template that introduces a loop nest at each level of the recursion:
template <int N>
struct rdloop {
  // Performs a nested loop over the set of loop indices in [lwb,
  // upb). The size of lwb and upb must be at least N. For each
  // index i1, ..., iN in [lwb, upb), calls func on the point
  // pt(is..., i1, ..., iN).
  template <class Func, class... Indices>
  static void loop(const Func &func, const int *lwb,
                   const int *upb, Indices... is) {
    for (int i = *lwb; i < *upb; i++) {
      rdloop<N-1>::loop(func, lwb+1, upb+1, is..., i);
    }
  }
};
We write our template as a class, since we will require a base case and would need partial function-template specialization, which is not supported by C++, to implement it purely with function templates. The class is parameterized by the dimensionality. Within the class is a single static member function template that is parameterized by a functor type and a variadic set of indices. The arguments to the function itself are a functor object, which will be applied in the innermost loop, lower and upper bounds for the remaining dimensions, and the set of indices computed so far.
The body introduces a new loop nest, using the lower and upper bounds, and recursively applies the template with one less dimension. The bound pointers are adjusted for the new dimension, and we pass the input indices along with the one for this dimension in the recursive call. Our base case, where there is only a single dimension, is then as follows:
template <>
struct rdloop<1> {
  template <class Func, class... Indices>
  static void loop(const Func &func, const int *lwb,
                   const int *upb, Indices... is) {
    for (int i = *lwb; i < *upb; i++) {
      func(pt(is..., i));
    }
  }
};
We construct a point from the collected set of indices from each dimension and then call the functor object on that point.
Now that we have a mechanism for constructing a set of nested loops, we start the recursion from a function object and domain as follows:
rdloop<N>::loop(func, domain.lwb.coords,
                domain.upb.coords);
In order to actually make use of this, we provide a loop abstraction as follows:
foreach (p, interior) {
  gridB[p] =
    gridA[p + pt( 0,  0,  1)] +
    gridA[p + pt( 0,  0, -1)] +
    gridA[p + pt( 0,  1,  0)] +
    gridA[p + pt( 0, -1,  0)] +
    gridA[p + pt( 1,  0,  0)] +
    gridA[p + pt(-1,  0,  0)] +
    WEIGHT * gridA[p];
};
We have the foreach keyword, which we will define shortly, that takes in a variable name to represent a point and the domain over which to iterate. We then have a loop body that uses the point variable. A semicolon appears after the body, and it is necessary due to how foreach is defined.
The loop body looks very much like the body of a lambda function, and since we require a function object in order to build the nested structure, it is natural to consider how we can arrange for the loop body to turn into a lambda function. We need a statement in which a lambda function can appear at the end, right before the terminating semicolon, and assignment fits this structure:
<var> = [<capture>](<parameters>) {
  <body>
};
Thus, we need to arrange for the foreach header to turn into the beginning of this statement:
<var> = [<capture>](<parameters>)
We would like the programmer to be able to use all local variables, so we should capture all variables by reference. The foreach also introduces a new variable for the point, so that should be in the parameter list:
<var> = [&](const point<N> &<name>)
There are several remaining things we need. First, we need to figure out the dimensionality of the point to use as the parameter. We can use decltype to do so from the domain:
<var> = [&](const decltype(<domain>.lwb) &<name>)
Second, we need a way to ensure that when this assignment happens, the nested loop structure is executed. We can do so by overloading the assignment operator of the object <var>. Finally, we also need to introduce the left-hand variable, preferably in its own scope. We can do both by introducing a dummy loop header:
#define foreach(p, dom) \
  for (auto _iter = (dom).iter(); !_iter.done; _iter.done = true) \
    _iter = [&](const decltype((dom).lwb) &p)
In order for this to work, we need the iter() method on a domain to give us an object whose assignment operator takes in a functor. This operator would then call the functor within a nested set of loops. The object also needs a done field in order to ensure the dummy loop executes exactly one iteration. We can add the following members to the rectdomain template:
template <int N>
struct rectdomain {
  ...
  struct fast_iter {
    const rectdomain &domain; // domain over which to iterate
    bool done;                // whether or not this loop has run
    // Constructs a fast_iter with the given domain.
    fast_iter(const rectdomain &dom)
      : domain(dom), done(false) {}
    // Loops over the associated domain, calling func on each point
    // in the domain.
    template <class Func>
    fast_iter &operator=(const Func &func) {
      rdloop<N>::loop(func, domain.lwb.coords,
                      domain.upb.coords);
      return *this;
    }
  };
  // Returns a fast_iter over this domain.
  fast_iter iter() const {
    return fast_iter(*this);
  }
};
The assignment operator of fast_iter is a template, taking in a functor object. It then uses our nested loop generation mechanism to generate a set of nested loops and call the functor from the innermost loop, with the appropriate point as the argument.
The result is a loop that has the simplicity of a range-based for loop but, depending on the compiler, the performance of a nested set of loops. As an example, with GCC 6.2 on the author's iMac computer, the range-based for loop takes 1.45 seconds to perform ten timesteps of the stencil above on a \(256^3\) grid, while the nested loops and the foreach loop each take 0.28 seconds. This demonstrates the power of metaprogramming to extend the features of a language.