Extended Polymerization - Memoization
- Divide and Conquer
- Memoization
- Dynamic Programming
- Benchmarking & Analysis (in progress)
In the previous post, we saw that our recursive solution to counting characters still scaled exponentially with the number of string expansions we needed to do. Although we had saved on memory, runtime performance was still too slow to solve "Part 2" of the Advent of Code problem that requires performing 40 iterations of string expansion.
Fortunately, there is a straightforward technique to improve our algorithm's runtime.
Overlapping Subproblems
Our recursive algorithm uses a "divide-and-conquer"[1] approach. Each recursive call to the function splits its problem into 2 smaller subproblems, asks for the answer to the 2 smaller subproblems, and combines those results to get its own answer.
Our divide-and-conquer algorithm has a particular characteristic that contributes to its poor performance. Because of the branching structure of the recursion, our function will have to evaluate the answer for the same arguments multiple times. In fact, the more iterations we try to perform, the more this occurs. This is called "overlapping subproblems".
For example, imagine a recursive function that works by adding the results from 2 subproblems. We can show the recursive relationship between function calls as a tree, where each node represents a call to the function with a particular set of arguments. If we imagine that sometimes two calls will ask for the same subproblem, the overlapping subproblems could look like this:
In this example the subproblems overlap almost perfectly, so that we can see the overlapping effect without having to draw many layers of the tree. If we trace the order that each function is called, we can see how many times each is evaluated:
The node E is evaluated twice. And at the next layer down, H and I are each evaluated 3 times. This means that to compute the result for node A, we are performing 5 redundant function calls - we passed the same arguments multiple times, and got the same subproblem result multiple times. If we continued to extend the tree downwards, this overlap would increase with each additional layer.
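The count can be reproduced with a small F# sketch. The tree shape is an assumption, chosen to match the counts above: A asks for B and C, B for D and E, C for E and F, D for G and H, E for H and I, and F for I and J. The sketch walks the tree naively and tallies how often each node is evaluated.

```fsharp
open System.Collections.Generic

// Hypothetical call tree matching the counts described above:
// each node maps to the two subproblems it asks for; leaves are absent.
let children : Map<char, char list> =
    Map.ofList [
        'A', [ 'B'; 'C' ]
        'B', [ 'D'; 'E' ]
        'C', [ 'E'; 'F' ]
        'D', [ 'G'; 'H' ]
        'E', [ 'H'; 'I' ]
        'F', [ 'I'; 'J' ]
    ]

// How many times each node has been evaluated.
let evaluations = Dictionary<char, int>()

// Naive recursion: every call re-evaluates its children, even when the
// same child was already evaluated on behalf of another parent.
let rec evaluate (node : char) =
    let previous =
        match evaluations.TryGetValue node with
        | true, count -> count
        | _ -> 0
    evaluations[node] <- previous + 1
    match Map.tryFind node children with
    | Some kids -> kids |> List.iter evaluate
    | None -> ()

evaluate 'A'

// Every evaluation of a node after its first one is redundant.
let redundantCalls =
    evaluations.Values |> Seq.sumBy (fun count -> count - 1)

for KeyValue(node, count) in evaluations do
    printfn "%c evaluated %d time(s)" node count
printfn "redundant calls: %d" redundantCalls
```

Running it reports E evaluated twice, H and I three times each, and 5 redundant calls in total.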
We can guess that our algorithm has overlapping subproblems based on intuition and its runtime performance. But we can also see overlaps by writing out or drawing the recursion tree for our character pairs. Here is an example from the previous post:
- QQ → E
- QW → W
- QE → Q
- WQ → W
- WW → E
- WE → Q
- EQ → Q
- EW → E
- EE → W
countChars("QW", 9) = countChars("QW", 8) + countChars("WW", 8)
countChars("WW", 9) = countChars("WE", 8) + countChars("EW", 8)
countChars("QW", 8) = countChars("QW", 7) + countChars("WW", 7)
countChars("WW", 8) = countChars("WE", 7) + countChars("EW", 7)
countChars("WE", 8) = countChars("WQ", 7) + countChars("QE", 7)
countChars("EW", 8) = countChars("EE", 7) + countChars("EW", 7)
...
Already after expanding a 3-character string 2 rounds, we can see a repeated function call: countChars("EW", 7) is used for both ("WW", 8) and ("EW", 8). The full problem input uses a longer starting string and more possible characters, but the same thing is highly likely to happen, especially after many rounds of expansion.
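The repetition can also be checked mechanically. This sketch uses the rules listed above, re-runs the naive recursion from the two top-level calls ("QW", 9) and ("WW", 9), and tallies how often each (pair, iterations) argument is requested. The function countRequests and its stopAt parameter are hypothetical names for illustration; the sketch only counts calls and does not compute character counts.

```fsharp
open System.Collections.Generic

// Insertion rules from the example above: pair -> inserted character.
let rules : Map<char * char, char> =
    Map.ofList [
        ('Q', 'Q'), 'E'; ('Q', 'W'), 'W'; ('Q', 'E'), 'Q'
        ('W', 'Q'), 'W'; ('W', 'W'), 'E'; ('W', 'E'), 'Q'
        ('E', 'Q'), 'Q'; ('E', 'W'), 'E'; ('E', 'E'), 'W'
    ]

// Tallies how many times each (pair, iterations) argument is requested by the
// naive recursion, expanding calls only while iterations > stopAt.
let countRequests (roots : ((char * char) * int) list) (stopAt : int) =
    let requests = Dictionary<(char * char) * int, int>()
    let rec visit (pair : (char * char), n : int) =
        let key = (pair, n)
        let previous =
            match requests.TryGetValue key with
            | true, count -> count
            | _ -> 0
        requests[key] <- previous + 1
        if n > stopAt then
            let left, right = pair
            let middle = Map.find pair rules
            visit ((left, middle), n - 1)
            visit ((middle, right), n - 1)
    roots |> List.iter visit
    requests

// The two top-level calls from the example.
let roots = [ (('Q', 'W'), 9); (('W', 'W'), 9) ]

// Two levels deep: ("EW", 7) is already requested twice.
let shallow = countRequests roots 7
let ewAt7 = shallow[(('E', 'W'), 7)]
printfn "(EW, 7) requested %d times" ewAt7

// All the way down: many calls, but few distinct arguments.
let deep = countRequests roots 0
let totalCalls = deep.Values |> Seq.sum
printfn "%d calls for %d distinct arguments" totalCalls deep.Count
```

Expanding all the way to 0 iterations makes 2,046 calls, yet with 9 possible pairs and 10 iteration values at most 90 of those arguments can be distinct.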
Memoization Concept
Memoization is an optimization technique that works by storing the results of previous function calls.[2] When the memoized function computes a result, that result is stored in a cache using the arguments passed to the function as a unique key. Whenever the function is called, it first checks the cache to see if it has been called before with the same arguments. If so, instead of computing the same result again from scratch, it can immediately return the result. In pseudocode, the pattern looks like this:
func memoizedExpensiveFunction(arguments)
    key := create_unique_key(arguments)
    if key in cache then
        return cache[key]
    else
        result := actuallyDoSlowExpensiveThing(arguments)
        cache[key] := result
        return result
    end
end
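As a runnable counterpart to the pseudocode, here is a minimal generic F# sketch. The memoize helper and the slowSquare demo function are hypothetical; the cache is a Dictionary keyed by the argument itself, which assumes the argument type has structural equality and hashing.

```fsharp
open System.Collections.Generic

// Generic memoization wrapper: caches results keyed by the argument itself.
// Only valid for pure functions; 'a must support equality and hashing
// (true for F# primitives, strings, tuples, and records).
let memoize<'a, 'b when 'a : equality> (f : 'a -> 'b) : 'a -> 'b =
    let cache = Dictionary<'a, 'b>()
    fun arg ->
        match cache.TryGetValue arg with
        | true, cached -> cached        // seen before: return the stored result
        | _ ->
            let result = f arg          // first time: do the expensive work...
            cache[arg] <- result        // ...and remember it for next time
            result

// A stand-in "expensive" function that also counts how often it really runs.
let executions = ref 0
let slowSquare (x : int) =
    executions.Value <- executions.Value + 1
    x * x

let fastSquare = memoize slowSquare
let results = [ fastSquare 4; fastSquare 4; fastSquare 5; fastSquare 4 ]
printfn "results = %A, slowSquare ran %d times" results executions.Value
```

Four calls produce [16; 16; 25; 16] while slowSquare runs only twice. Note that a wrapper like this only caches calls made through fastSquare: a recursive function's inner calls would still go to the original, un-memoized function, which is one reason the F# implementation in this post passes an explicit cache into every recursive call instead.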
This pattern can be useful for functions with these characteristics:
- It is a "pure" function - i.e. whenever the function is called with the same arguments, it will return the same result, and it does not perform side effects
- It will be called many times with the same arguments
- The program can afford to use the memory needed for the cached results
Memoizing a procedure trades space (memory to store previous results) for time (not having to calculate the same result repeatedly). Whether that tradeoff is an improvement depends on the particular function and the program's requirements. For example:
- The more unique arguments the function is called with, the larger the cache will need to be
- If the function is rarely called more than once with the same arguments, the cache will have little benefit
Implementing Memoization
Memoizing our recursive algorithm will require only a few small changes, but first we need to identify how to store our intermediate results. Our premise is that if our algorithm is called with the same arguments, it will always return the same result. We want a way to store the result for a particular set of arguments, and retrieve that stored result if our function is called again with the same arguments in the future. This sounds like an associative array - given a key, it returns the value associated with that key (if one exists). Most standard libraries include a data structure like this under some name, e.g. maps (F#), dicts (Python), tables (Lua), objects (JavaScript), or hashmaps (Java).
F# on .NET provides 2 options for this:
- Dictionary<K, V>, from .NET's System.Collections.Generic
- Map<K, V>, from the F# standard library
Normally in F# I would lean toward the standard library, as it follows functional patterns instead of object-oriented ones. But in this case, it is important that all the recursive function calls share the same table of results, and the F# data structures have immutable semantics - when you add a value to a Map, you get a new Map with the value included, while the original map remains unchanged. Dictionary is an object-oriented data structure whose contents we can freely mutate, and it is passed by reference.
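A small F# illustration of the difference; the string keys and values are arbitrary placeholders. Adding to a Map leaves the original unchanged, while a Dictionary handed to a function is mutated in place, which is what lets every recursive call share one cache.

```fsharp
open System.Collections.Generic

// Map: "adding" returns a new map and leaves the original untouched.
let original : Map<string, int64> = Map.empty |> Map.add "QW" 1L
let extended = original |> Map.add "WW" 2L
printfn "Map: original has %d entries, extended has %d" original.Count extended.Count

// Dictionary: a mutable reference type, so a function that receives it
// modifies the caller's instance rather than a copy.
let addEntry (table : Dictionary<string, int64>) (key : string) (value : int64) =
    table[key] <- value

let sharedTable = Dictionary<string, int64>()
addEntry sharedTable "QW" 1L
addEntry sharedTable "WW" 2L
printfn "Dictionary: sharedTable has %d entries" sharedTable.Count
```

The Map version ends with 1 entry in original and 2 in extended; the Dictionary ends with 2 entries, both added through the function.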
type Cache = Dictionary<?, ?>
Now we can determine the types of the keys and values to store in the dictionary. The values must match whatever our function returns, and the key must be related to the function's arguments. The signature of our recursive function is like this:
let rec countForPair (rules : Rules) (pair : Pair) (iterations : int) : Counter<char> =
    // ...
Since it returns a Counter<char>, that is the type of the values kept in the dictionary:
type Cache = Dictionary<?, Counter<char>>
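The key for each entry must identify a particular set of arguments: calls with the same arguments must produce the same key, so that they can share a cached result. Usually this means combining the function's arguments into a string or integer key. A simple way to do this in many programming languages is to use a standard hash function[3] to convert all the argument values into an integer. There are standard algorithms for generating hash values from multiple values of different types[4], and .NET includes one of these for all its basic types: the Object.GetHashCode method. Keep in mind that hash codes are not guaranteed to be unique: two different sets of arguments can occasionally hash to the same integer (a collision), and an integer-keyed cache would then return the wrong result for one of them.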
For our particular problem we can make things a little easier by ignoring the rules parameter, since the function will always be called with the same rule set. Then we can use the tuple implementation of GetHashCode to turn each set of function arguments into an integer key (this relies on no two argument sets colliding for the inputs we use):
let getCacheKey (a : char, b : char) (n : int) : int =
    (a, b, n).GetHashCode()
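Because hash codes can collide, one alternative (a sketch, not the approach this post takes) is to use the argument tuple itself as the Dictionary key. Dictionary still uses GetHashCode to pick a bucket, but it then compares candidate keys with Equals, so two different argument sets can never share an entry. TupleCache and getTupleKey are hypothetical names, and the string values stand in for Counter<char> results.

```fsharp
open System.Collections.Generic

// Hypothetical cache keyed on the full (left, right, iterations) tuple.
// F# tuples have structural equality and hashing, so equal tuples find
// the same entry and unequal tuples never share one.
type TupleCache<'v> = Dictionary<char * char * int, 'v>

let getTupleKey (a : char, b : char) (n : int) : char * char * int = (a, b, n)

let cache = TupleCache<string>()
cache[getTupleKey ('Q', 'W') 9] <- "result for QW after 9 iterations"

// A freshly built, equal tuple finds the stored entry...
let foundQW = cache.ContainsKey(getTupleKey ('Q', 'W') 9)
// ...while a different pair with the same iteration count does not.
let foundWQ = cache.ContainsKey(getTupleKey ('W', 'Q') 9)
printfn "QW found: %b, WQ found: %b" foundQW foundWQ
```

Adopting this would mean declaring the cache as Dictionary<char * char * int, Counter<char>> and having getCacheKey return the tuple instead of its hash code.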
Lastly, we now have the full type for our cache:
type Cache = Dictionary<int, Counter<char>>
With these pieces in place, memoizing our previous recursive algorithm is as simple as the pseudocode from earlier. Major changes in the following code are marked with NEW:
module Memoized =
    // NEW
    type Cache = Dictionary<int, Counter<char>>

    // NEW
    let getCacheKey (a : char, b : char) (n : int) : int =
        (a, b, n).GetHashCode()

    let rec countForPair (cache : Cache) (rules : Rules) (pair : Pair) (iterations : int) =
        // NEW
        let key = getCacheKey pair iterations
        // NEW
        if cache.ContainsKey(key) then
            cache[key]
        else
            let leftChar, rightChar = pair
            if iterations = 0 then
                // Same as "plain" recursive version
                Counter.empty ()
                |> Counter.incr leftChar
                |> Counter.incr rightChar
            else
                // Almost identical to plain recursive version;
                // we pass the cache object to recursive calls.
                let sharedChar = Map.find pair rules
                let leftSubResult = countForPair cache rules (leftChar, sharedChar) (iterations - 1)
                let rightSubResult = countForPair cache rules (sharedChar, rightChar) (iterations - 1)
                // NEW
                // Instead of returning a result immediately, store it in the cache
                // for later calls to re-use:
                let result =
                    Counter.add leftSubResult rightSubResult
                    |> Counter.decr sharedChar
                cache[key] <- result
                result

    let countAllCharacters (rules : Rules) (template : char list) (iterations : int) =
        // NEW
        // Create a cache object - this will be passed by reference all the way
        // down the chain of recursion, so later recursive calls can re-use
        // results calculated by earlier ones.
        let sharedCache = Cache()

        // Pass the shared cache object to countForPair.
        let countWithOverlap (index : int, pair : Pair) =
            let pairResult = countForPair sharedCache rules pair iterations
            if index = 0 then
                pairResult
            else
                // Counter.decr mutates its argument, and pairResult may be the
                // object stored in the cache (e.g. if the template repeats a
                // pair), so decrement a copy instead.
                Counter.decr (fst pair) (Dictionary<char, int64>(pairResult))

        // same as original
        template
        |> List.windowed 2
        |> List.map (fun xs -> xs[0], xs[1]) // convert 2-element lists into tuples
        |> List.indexed
        |> List.map countWithOverlap
        |> List.fold Counter.add (Counter.empty ())
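For readers who want to try the idea without the rest of the script, here is a self-contained sketch. It is not the post's code: it swaps the Dictionary-based Counter for an immutable Map<char, int64> (so a cached result can never be modified after it is stored), keys the cache on the (pair, iterations) tuple, caches the base case too, and uses the hypothetical three-character template "QWW", whose pairs match the recursion example earlier. The rules are the ones listed earlier in the post.

```fsharp
open System.Collections.Generic

// Immutable stand-in for the post's Counter type: character -> count.
type Counts = Map<char, int64>

let countOf (c : char) (counts : Counts) =
    Map.tryFind c counts |> Option.defaultValue 0L

let increment (c : char) (counts : Counts) : Counts = Map.add c (countOf c counts + 1L) counts
let decrement (c : char) (counts : Counts) : Counts = Map.add c (countOf c counts - 1L) counts

let addCounts (a : Counts) (b : Counts) : Counts =
    Map.fold (fun acc c n -> Map.add c (countOf c acc + n) acc) a b

// Insertion rules from the example earlier in the post.
let rules : Map<char * char, char> =
    Map.ofList [
        ('Q', 'Q'), 'E'; ('Q', 'W'), 'W'; ('Q', 'E'), 'Q'
        ('W', 'Q'), 'W'; ('W', 'W'), 'E'; ('W', 'E'), 'Q'
        ('E', 'Q'), 'Q'; ('E', 'W'), 'E'; ('E', 'E'), 'W'
    ]

type PairCache = Dictionary<(char * char) * int, Counts>

// Counts the characters produced by expanding one pair, memoized on (pair, iterations).
let rec countForPair (cache : PairCache) (pair : char * char) (iterations : int) : Counts =
    let key = (pair, iterations)
    match cache.TryGetValue key with
    | true, cached -> cached
    | _ ->
        let left, right = pair
        let result =
            if iterations = 0 then
                Map.empty |> increment left |> increment right
            else
                let middle = Map.find pair rules
                let leftCounts = countForPair cache (left, middle) (iterations - 1)
                let rightCounts = countForPair cache (middle, right) (iterations - 1)
                // Both halves contain the middle character, so remove one copy.
                addCounts leftCounts rightCounts |> decrement middle
        cache[key] <- result
        result

// Returns the character counts and the number of cache entries created.
let countAllCharacters (template : char list) (iterations : int) =
    let cache = PairCache()
    let total =
        template
        |> List.pairwise
        |> List.mapi (fun index (left, right) ->
            let counts = countForPair cache (left, right) iterations
            // Neighbouring pairs share a character; count it only once.
            if index = 0 then counts else decrement left counts)
        |> List.fold addCounts Map.empty
    total, cache.Count

let counts, cacheEntries = countAllCharacters [ 'Q'; 'W'; 'W' ] 40
printfn "%A" counts
printfn "cache entries: %d" cacheEntries
```

Because each (pair, iterations) combination is computed at most once, the cache for 40 iterations over a three-letter alphabet can never hold more than 9 pairs × 41 iteration values = 369 entries.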
Testing the Memoized Version
We can run our updated algorithm in FSI using a helper function similar to the one we've used previously:
let countMemoized n =
    sampleString
    |> String.toCharList
    |> fun cs ->
        Memoized.countAllCharacters sampleRules cs n
    |> Seq.map (fun kv -> kv.Key, kv.Value)
    |> Seq.toList
In fact, the two test functions only differ by which version of countAllCharacters is called, so we can combine them into one function that takes the solver as a parameter:
let countWithFunction solver n =
    sampleString
    |> String.toCharList
    |> fun cs ->
        solver sampleRules cs n
    |> Seq.map (fun (kv: KeyValuePair<char, int64>) -> kv.Key, kv.Value)
    |> Seq.toList

let countRecursive = countWithFunction Recursive.countAllCharacters
let countMemoized = countWithFunction Memoized.countAllCharacters
Let's run it:
> #load "ExtendedPolymerization.fsx";;
[Loading C:\Users\mrsei\source\advent of code\2021\ExtendedPolymerization.fsx]
module FSI_0002.ExtendedPolymerization
...
> open FSI_0002.ExtendedPolymerization;;
> #time "on";;
--> Timing now on
> countRecursive 10;;
Real: 00:00:00.002, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list = [('E', 778L); ('Q', 915L); ('W', 356L)]
> countMemoized 10;;
Real: 00:00:00.000, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list = [('E', 778L); ('Q', 915L); ('W', 356L)]
So far so good! Our two implementations are returning the same results. And even for only 10 iterations, the memoized version seems to be completing faster than the original. Checking for larger numbers of iterations:
> countRecursive 20;;
Real: 00:00:01.748, CPU: 00:00:01.718, GC gen0: 314, gen1: 4, gen2: 1
val it: (char * int64) list =
[('E', 741812L); ('Q', 1198035L); ('W', 157306L)]
> countMemoized 20;;
Real: 00:00:00.000, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list =
[('E', 741812L); ('Q', 1198035L); ('W', 157306L)]
> countRecursive 25;;
Real: 00:00:54.182, CPU: 00:00:54.156, GC gen0: 9996, gen1: 8, gen2: 1
val it: (char * int64) list =
[('E', 23269761L); ('Q', 40527870L); ('W', 3311234L)]
> countMemoized 25;;
Real: 00:00:00.000, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list =
[('E', 23269761L); ('Q', 40527870L); ('W', 3311234L)]
For 25 iterations the memoized version is still completing quickly enough that the FSI timer can't measure it. How high can we go?
> countMemoized 30;;
Real: 00:00:00.000, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list =
[('E', 734775626L); ('Q', 1343007351L); ('W', 69700672L)]
> countMemoized 40;;
Real: 00:00:00.001, CPU: 00:00:00.000, GC gen0: 1, gen1: 0, gen2: 0
val it: (char * int64) list =
[('E', 741403356022L); ('Q', 1426736052417L); ('W', 30883847114L)]
> countMemoized 50;;
Real: 00:00:00.001, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list =
[('E', 754319967443786L); ('Q', 1483795444085587L); ('W', 13684402155876L)]
> countMemoized 60;;
Real: 00:00:00.001, CPU: 00:00:00.000, GC gen0: 0, gen1: 0, gen2: 0
val it: (char * int64) list =
[('E', 770262653733359560L); ('Q', 1529516899017849463L);
('W', 6063456462484930L)]
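A rough back-of-the-envelope comparison helps explain these timings. It assumes the plain recursive version makes two calls at n - 1 for every call at n > 0 (as the recursion above does) and that the sample template has three characters, i.e. two pairs (consistent with the 2,049 characters counted after 10 iterations). The helper names are hypothetical.

```fsharp
// Calls made by the plain recursion for one top-level pair: a full binary
// tree of depth n, i.e. 2^(n+1) - 1 calls.
let plainCallsPerPair (n : int) : int64 = (1L <<< (n + 1)) - 1L

// Upper bound on distinct (pair, iterations) subproblems for the memoized
// version: alphabetSize^2 possible pairs times (n + 1) iteration values.
let memoizedBound (alphabetSize : int) (n : int) : int64 =
    int64 alphabetSize * int64 alphabetSize * int64 (n + 1)

// Two top-level pairs (three-character template), three-letter alphabet.
for n in [ 10; 20; 25; 40 ] do
    printfn "n = %d: plain calls = %d, memoized subproblems <= %d"
        n (2L * plainCallsPerPair n) (memoizedBound 3 n)
```

Going from 20 to 25 iterations multiplies the plain call count by 2^5 = 32, which lines up with the FSI timings above (about 1.7 s versus 54 s), while the memoized bound only grows from 189 to 234.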
Aside - Integer Overflow
Our new implementation seems to handle the scale of part 2 of our original problem, and beyond, with relative ease. We do eventually hit a problem, but it isn't the running time:
> countMemoized 70;;
System.Exception: tried to decrement count for key 'Q' but it was missing or zero
at Microsoft.FSharp.Core.PrintfModule.PrintFormatToStringThenFail@1448.Invoke(String message)
at FSI_0002.ExtendedPolymerization.Counter.decr[t](t key, Dictionary`2 counter) in C:\..\ExtendedPolymerization.fsx:line 60
at FSI_0002.ExtendedPolymerization.Memoized.countForPair(Dictionary`2 cache, FSharpMap`2 rules, Char leftChar, Char rightChar, Int32 iterations) in C:\..\ExtendedPolymerization.fsx:line 165
...removed repeats from recursion
at FSI_0002.ExtendedPolymerization.Memoized.countForPair(Dictionary`2 cache, FSharpMap`2 rules, Char leftChar, Char rightChar, Int32 iterations) in C:\..\ExtendedPolymerization.fsx:line 164
at FSI_0002.ExtendedPolymerization.Memoized.countWithOverlap@184-1.Invoke(Tuple`2 tupledArg) in C:\..\ExtendedPolymerization.fsx:line 186
at FSI_0002.ExtendedPolymerization.Memoized.countAllCharacters(FSharpMap`2 rules, FSharpList`1 template, Int32 iterations) in C:\..\ExtendedPolymerization.fsx:line 180
at FSI_0002.ExtendedPolymerization.countMemoized@235.Invoke(FSharpMap`2 rules, FSharpList`1 template, Int32 iterations)
at FSI_0002.ExtendedPolymerization.countWithFunction[a,b](FSharpFunc`2 solver, a n) in C:\..\ExtendedPolymerization.fsx:line 227
at FSI_0002.ExtendedPolymerization.countMemoized@235-1.Invoke(Int32 n)
at <StartupCode$FSI_0026>.$FSI_0026.main@() in C:\..\stdin:line 27
Stopped due to error
This error is particular to my implementation of Counter from part 1. The important parts of the code are as follows:
type Counter<'t when 't: comparison> = Dictionary<'t, int64>

module Counter =
    // ...
    let decr (key : 't) (counter : Counter<'t>) =
        let currentValue = counter.GetValueOrDefault(key, 0L)
        if currentValue > 0L then
            counter[key] <- currentValue - 1L
        else
            failwithf "tried to decrement count for key %A but it was missing or zero" key
        counter
That is where the error message comes from - it is created by failwithf. Two elements lead to the error:
- The Counter type stores results using int64, which is a signed integer - it can be negative.
- Counter.decr's error message says that currentValue was "missing or zero", but what it actually checks is whether currentValue is greater than zero, so a negative value produces the same message.
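One way to make this failure easier to diagnose is to report negative counts separately. The following is a hypothetical variant of Counter.decr, written directly against Dictionary<'t, int64> so that it runs on its own; it is not the post's implementation.

```fsharp
open System.Collections.Generic

// Hypothetical variant of Counter.decr that distinguishes a missing key,
// a zero count, and a negative count (a likely sign of integer overflow).
let decr<'t when 't : equality> (key : 't) (counter : Dictionary<'t, int64>) =
    match counter.TryGetValue key with
    | false, _ ->
        failwithf "tried to decrement count for key %A but it was missing" key
    | true, 0L ->
        failwithf "tried to decrement count for key %A but it was zero" key
    | true, value when value < 0L ->
        failwithf "count for key %A is negative (%d); possible integer overflow" key value
    | true, value ->
        counter[key] <- value - 1L
    counter
```

With this version, the failure for 70 iterations would point at a negative count rather than a "missing or zero" one.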
How might the values become negative numbers, if we do not allow decrementing them unless they are greater than 0?
All fixed-size integer types have a minimum and a maximum value. The size of those values depends on the number of bits used by the type - hence int64 can hold larger values than int32, and so on. By default[5], an operation whose result would exceed the maximum value of an integer type causes an integer overflow, and the value wraps around from one end of its range to the other:
> System.Int32.MaxValue;;
val it: int = 2147483647
> System.Int32.MaxValue + 1;;
val it: int = -2147483648
.NET's System.Int64 has a maximum value of 9,223,372,036,854,775,807. In the output for countMemoized 60, we can see the count for "Q" approaching this value - it is only about 6 times smaller, and the counts roughly double with each iteration:
Int64.MaxValue: 9,223,372,036,854,775,807
Q: 1,529,516,899,017,849,463
E: 770,262,653,733,359,560
W: 6,063,456,462,484,930
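We can also predict when overflow must happen without running the solver. Each step inserts one character between every adjacent pair, so a polymer of length L becomes 2L - 1 characters long, giving 2^n × (L0 - 1) + 1 characters after n steps. This sketch assumes a three-character template (consistent with the 2,049 characters counted after 10 iterations), and polymerLength is a hypothetical helper. For n = 60 it gives 2^61 + 1, which is exactly the sum of the three counts printed for countMemoized 60.

```fsharp
open System.Numerics

// Length after n insertion steps: L(n) = 2 * L(n - 1) - 1,
// which solves to L(n) = 2^n * (L(0) - 1) + 1.
let polymerLength (templateLength : int) (n : int) : BigInteger =
    BigInteger.Pow(2I, n) * BigInteger(templateLength - 1) + 1I

let int64Max = BigInteger(System.Int64.MaxValue)
let uint64Max = BigInteger(System.UInt64.MaxValue)

for n in [ 60; 70 ] do
    let total = polymerLength 3 n
    // With only three characters, the largest count is at least total / 3.
    let largestAtLeast = total / 3I
    printfn "n = %d: total = %O, largest count >= %O" n total largestAtLeast
    printfn "  exceeds int64: %b, exceeds uint64: %b"
        (largestAtLeast > int64Max) (largestAtLeast > uint64Max)
```

For n = 70, even the lower bound on the largest count (about 7.9 × 10^20) is far beyond both Int64.MaxValue and UInt64.MaxValue (about 1.8 × 10^19), so no 64-bit integer type could hold the answer.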
We can also opt in to checked arithmetic by using the Checked.(+) addition operator instead of the regular + operator in Counter.incr and Counter.add. This turns on runtime overflow checks: instead of silently wrapping around, an overflowing operation throws an exception. After making this change, we can clearly see the overflow occurring:
> countMemoized 60;;
val it: (char * int64) list =
[('E', 770262653733359560L); ('Q', 1529516899017849463L);
('W', 6063456462484930L)]
> countMemoized 70;;
System.OverflowException: Arithmetic operation resulted in an overflow.
at FSI_0002.ExtendedPolymerization.Counter.addKey@67[t](Dictionary`2 left, Dictionary`2 right, t k, Dictionary`2 dest) in C:\...\ExtendedPolymerization.fsx:line 68
at FSI_0002.ExtendedPolymerization.Counter.add@79.Invoke(t k) in C:\...\ExtendedPolymerization.fsx:line 79
at Microsoft.FSharp.Collections.SetTreeModule.iter[T](FSharpFunc`2 f, SetTree`1 t) in D:\a\_work\1\s\src\FSharp.Core\set.fs:line 279
at FSI_0002.ExtendedPolymerization.Counter.add[t](Dictionary`2 left, Dictionary`2 right) in C:\...\ExtendedPolymerization.fsx:line 79
at FSI_0002.ExtendedPolymerization.Memoized.countForPair(Dictionary`2 cache, FSharpMap`2 rules, Char leftChar, Char rightChar, Int32 iterations) in C:\...\ExtendedPolymerization.fsx:line 165
... removed repeated lines
at FSI_0002.ExtendedPolymerization.Memoized.countForPair(Dictionary`2 cache, FSharpMap`2 rules, Char leftChar, Char rightChar, Int32 iterations) in C:\...\ExtendedPolymerization.fsx:line 164
at FSI_0002.ExtendedPolymerization.Memoized.countWithOverlap@184-1.Invoke(Tuple`2 tupledArg) in C:\...\ExtendedPolymerization.fsx:line 184
at FSI_0002.ExtendedPolymerization.Memoized.countAllCharacters(FSharpMap`2 rules, FSharpList`1 template, Int32 iterations) in C:\...\ExtendedPolymerization.fsx:line 180
at FSI_0002.ExtendedPolymerization.countMemoized@235.Invoke(FSharpMap`2 rules, FSharpList`1 template, Int32 iterations)
at FSI_0002.ExtendedPolymerization.countWithFunction[a,b](FSharpFunc`2 solver, a n) in C:\...\ExtendedPolymerization.fsx:line 227
at FSI_0002.ExtendedPolymerization.countMemoized@235-1.Invoke(Int32 n)
at <StartupCode$FSI_0005>.$FSI_0005.main@() in C:\...\stdin:line 4
Stopped due to error
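The difference between the two kinds of addition is easy to see in isolation. A minimal sketch, independent of the Counter code:

```fsharp
// Default arithmetic wraps around silently on overflow.
let wrapped = System.Int64.MaxValue + 1L
printfn "unchecked: %d" wrapped

// Checked.(+) performs the same addition but throws OverflowException instead.
let checkedResult =
    try
        Some (Checked.(+) System.Int64.MaxValue 1L)
    with :? System.OverflowException ->
        None
printfn "checked: %A" checkedResult
```

The unchecked sum wraps to Int64.MinValue, while the checked one throws and yields None. Alternatively, open Checked at the top of a module replaces the standard arithmetic operators with their checked versions for the rest of that module.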
How to mitigate integer overflow depends on a particular program's purpose. In this case, it is obvious in hindsight that Counter can use uint64 instead of int64 - it is never supposed to store a negative value, and uint64 can store values up to twice as large in the same number of bits. Since the counts roughly double with every iteration, though, that buys only about one extra iteration. If we were required to handle arbitrarily large numbers of iterations, perhaps we could use something like BigInteger instead. But for this Advent of Code problem, we know that the highest number of iterations we must handle is only 40, and we don't need to make any changes at all to handle that (unless we wanted to).
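A quick check of the "one extra iteration" estimate, using Q's count after 60 iterations from the output above: the sketch doubles it under checked arithmetic until the addition overflows. doublingsBeforeOverflow is a hypothetical helper, and plain doubling is only a rough model of one more iteration.

```fsharp
// Q's count after 60 iterations, copied from the FSI output above.
let qAt60 = 1529516899017849463L

// Doubles a value with the given (checked) addition until it overflows,
// returning how many doublings succeeded.
let doublingsBeforeOverflow (add : 'a -> 'a -> 'a) (start : 'a) : int =
    let rec loop value steps =
        let doubled =
            try Some (add value value)
            with :? System.OverflowException -> None
        match doubled with
        | Some next -> loop next (steps + 1)
        | None -> steps
    loop start 0

let int64Doublings = doublingsBeforeOverflow (fun (a : int64) b -> Checked.(+) a b) qAt60
let uint64Doublings = doublingsBeforeOverflow (fun (a : uint64) b -> Checked.(+) a b) (uint64 qAt60)
printfn "int64: %d more doublings, uint64: %d more doublings" int64Doublings uint64Doublings
```

It reports 2 doublings for int64 and 3 for uint64: the unsigned type buys exactly one more doubling.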
Next, we will test an alternative strategy for reducing the runtime.
[1] https://en.wikipedia.org/wiki/Divide-and-conquer_algorithm
[5] On .NET - this can vary by language, runtime, or compiler options