A couple of months ago I finally had to admit I wasn't smart enough to solve a few of the levels in Snakebird, a puzzle game. The only way to salvage some pride was to write a solver, and pretend that writing a program to do the solving is basically as good as having solved the problem myself. The C++ code for the resulting program is on Github. Most of what's discussed in the post is implemented in search.h and compress.h. This post deals mainly with optimizing a breadth-first search that's estimated to use 50-100GB of memory to run on a memory budget of 4GB.
There will be a follow-up post that deals with the specifics of the game. For this post, all you need to know is that I could not see any good alternatives to the brute force approach, since none of the usual tricks worked. There are a lot of states since there are multiple movable or pushable objects, and the shape of some of them matters and changes during the game. There were no viable conservative heuristics for algorithms like A* to narrow down the search space. The search graph was directed and implicit, so searching both forward and backward simultaneously was not possible. And a single move could cause the state to change in a lot of unrelated ways, so nothing like Zobrist hashing was going to be viable.
A back of the envelope calculation suggested that the biggest puzzle was going to have on the order of 10 billion states after eliminating all symmetries. Even after packing the state representation as tightly as possible, the state size was on the order of 8-10 bytes depending on the puzzle. 100GB of memory would be trivial at work, but this was my home machine with 16GB of RAM. And since Chrome needs 12GB of that, my actual memory budget was more like 4GB. Anything in excess of that would have to go to disk (the spinning rust kind).
How do we fit 100GB of data into 4GB of RAM? Either a) the states would need to be compressed to 1/20th of their original already optimized size, b) the algorithm would need to be able to efficiently page state to disk and back, c) a combination of the above, or d) I should buy more RAM or rent a big VM for a few days. Option D was out of the question due to being boring. Options A and C seemed out of the question after a proof of concept with gzip: a 50MB blob of states compressed to about 35MB. That's about 7 bytes per state, while my budget was more like 0.4 bytes per state. So option B it was, even though a breadth-first search looks pretty hostile to secondary storage.
Table of contents
This is a somewhat long post, so here's a brief overview of the sections ahead:
- A textbook BFS - What's the normal formulation of breadth-first search like, and why is it not suitable for storing parts of the state on disk?
- A sort + merge BFS - Changing the algorithm to efficiently do deduplications in batches.
- Compression - Reducing the memory use by 100x with a combination of off-the-shelf and custom compression.
- Oh no, I've cheated! - The first few sections glossed over something; it's not enough to know there is a solution, we need to know what the solution is. In this section the basic algorithm is updated to carry around enough data to reconstruct a solution from the final state.
- Sort + merge with multiple outputs - Keeping more state totally negates the compression gains. The sort + merge algorithm needs to be updated to keep two outputs: one that compresses well used during the search, and another that's just used to reconstruct the solution after one is found.
- Swapping - Swapping on Linux sucks even more than I thought.
- Compressing new states before merging - So far the memory optimizations have just been concerned with the visited set. But it turns out that the list of newly generated states is much larger than one might think. This section shows a scheme for representing the new states more efficiently.
- Saving space on the parent states - Investigating some CPU/memory tradeoffs for reconstructing the solution at the end.
- What didn't or might not work - Some things that looked promising but I ended up reverting, and others that research suggested would work but my intuition said wouldn't for this case.
A textbook BFS
So what does a breadth-first search look like, and why would it be disk-unfriendly? Before this little project I'd only ever seen variants of the textbook formulation, something like this:
    def bfs(graph, start, end):
        visited = {start}
        todo = [start]
        while todo:
            node = todo.pop_first()
            if node == end:
                return True
            for kid in adjacent(node):
                if kid not in visited:
                    visited.add(kid)
                    todo.push_back(kid)
        return False
As the program produces new candidate nodes, each node is checked against a hash table of already visited nodes. If it's already present in the hash table, we ignore the node. Otherwise it's added both to the queue and the hash table. Sometimes the 'visited' information is carried in the nodes rather than in a side-table; but that's a dodgy optimization to start with, and totally impossible when the graph is implicit rather than explicit.
Why is a hash table problematic? Because hash tables will tend to have a totally random memory access pattern. If they don't, it's a bad hash function and the hash table will probably perform terribly due to collisions. This random access pattern can cause performance issues even when the data fits in memory: an access to a huge hash table is pretty likely to cause both a cache and TLB miss. But if a significant chunk of the data is actually on disk rather than in memory? It'd be disastrous: something on the order of 10ms per lookup.
With 10G unique states we'd be looking at about four months of waiting for disk IO just for the hash table accesses. That can't work; the problem absolutely needs to be transformed such that the program can process big batches of data in one go.
A sort + merge BFS
If we wanted to batch the data access as much as possible, what would be the maximum achievable coarseness? Since the program can't know which nodes to process at depth N+1 before depth N has been fully processed, it seems obvious that we have to do our deduplication of states at least once per depth.
Dealing with a whole layer at one time allows ditching hash tables, and representing the visited set and the new states as sorted streams of some sort (e.g. file streams, arrays, lists). We can trivially find the new visited set with a set union on the streams, and equally trivially find the todo set with a set difference.
The two set operations can be combined to work on a single pass through both streams. Basically peek into both streams, process the smaller element, and then advance the stream that the element came from (or both streams if the elements at the head were equal). In either case, add the element to the new visited set. When advancing just the stream of new states, also add the element to the new todo set:
    def bfs(graph, start, end):
        visited = Stream()
        todo = Stream()
        visited.add(start)
        todo.add(start)
        while True:
            new = []
            for node in todo:
                if node == end:
                    return True
                for kid in adjacent(node):
                    new.push_back(kid)
            new_stream = Stream()
            for node in new.sorted().uniq():
                new_stream.add(node)
            todo, visited = merge_sorted_streams(new_stream, visited)
        return False

    # Merges sorted streams new and visited. Returns a sorted stream of
    # elements that were only present in new, and another sorted
    # stream containing the elements that were present in either or
    # both of new and visited.
    def merge_sorted_streams(new, visited):
        out_todo, out_visited = Stream(), Stream()
        while visited or new:
            if visited and new:
                if visited.peek() == new.peek():
                    out_visited.add(visited.pop())
                    new.pop()
                elif visited.peek() < new.peek():
                    out_visited.add(visited.pop())
                elif visited.peek() > new.peek():
                    out_todo.add(new.peek())
                    out_visited.add(new.pop())
            elif visited:
                out_visited.add(visited.pop())
            elif new:
                out_todo.add(new.peek())
                out_visited.add(new.pop())
        return out_todo, out_visited
The data access pattern is now perfectly linear and predictable: there are no random accesses at all during the merge. Disk latency thus becomes irrelevant, and the only thing that matters is throughput.
What does the theoretical performance look like with the simplified data distribution of 100 depth levels and 100M states per depth? The average state will be both read and written 50 times. That's 10 bytes/state * 5G states * 50 = 2.5TB. My hard drive can supposedly read and write at a sustained 100MB/s, which would mean (2 * 2.5TB) / (100MB/s) =~ 50,000 seconds, or about 14 hours, spent on the IO. That's a couple of orders of magnitude better than the earlier four month estimate!
It's worth noting that this simplistic model is not considering the size of the newly generated states. Before the merge step, they need to be kept in-memory for the sorting + deduplication. We'll look closer at that in a later section.
Compression
In the introduction I mentioned that compressing the states didn't look very promising in the initial experiments, with gzip shaving only about 30% off the size. But after the above algorithm change the states are now ordered. That should be a lot easier to compress.
To test this theory, I used zstd on a puzzle of 14.6M states, with each state being 8 bytes. After the sorting they compressed to an average of 1.4 bytes per state. That seems like a solid improvement. Not quite enough to run the whole program in memory, but it could plausibly cut the disk IO to just a couple of hours.
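For the curious, here's roughly what that experiment looks like sketched in Python (the real code was C++; the python-zstandard bindings, the compression level, and the 8-byte little-endian packing are just assumptions for the sketch):

    import struct

    import zstandard  # assumed: the python-zstandard bindings

    def bytes_per_state(states, sort=True):
        """Pack 64-bit states, optionally sort them first, and report the
        zstd-compressed size per state."""
        if sort:
            states = sorted(states)
        blob = b"".join(struct.pack("<Q", s) for s in states)
        compressed = zstandard.ZstdCompressor(level=19).compress(blob)
        return len(compressed) / len(states)

Running it with sort=False vs. sort=True is the whole comparison: the same data, just ordered before compressing.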
Is there any way to do better than a state of the art general purpose compression algorithm, if you know something about the structure of the data? Almost certainly. One good example is the PNG format. Technically the compression is just a standard Deflate pass. But rather than compress the raw image data, the image is first transformed using PNG filters. A PNG filter is basically a formula for predicting the value of a byte in the raw data from the value of the same byte on the previous row and/or the same byte of the previous pixel. For example the 'up' filter transforms each byte by subtracting the previous row's value from it during compression, and doing the inverse when decompressing. Given the kinds of images PNG is meant for, the result will probably mostly consist of zeroes or numbers close to zero. Deflate can compress these far better than the raw data.
Can we apply a similar idea to the state records of the BFS? Seems like it should be possible. Just like in PNGs, there's a fixed row size, and we'd expect adjacent rows to be very similar. The first tries with a subtraction/addition filter followed by zstd resulted in another 40% improvement in compression ratios: 0.87 bytes per state. The filtering operations are trivial, so this was basically free from a CPU consumption point of view.
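As a concrete illustration, here's a minimal Python sketch of that kind of 'up'-style delta filter over fixed-size records (the function names are mine, and the real filter operated on packed C++ structs):

    def up_filter(rows):
        # Subtract the previous row byte-wise (mod 256); like PNG's 'up' filter.
        prev = bytes(len(rows[0]))
        out = bytearray()
        for row in rows:
            out.extend((b - p) & 0xFF for b, p in zip(row, prev))
            prev = row
        return bytes(out)

    def up_unfilter(data, row_size):
        # Inverse: add the previous *reconstructed* row back in.
        prev = bytes(row_size)
        rows = []
        for i in range(0, len(data), row_size):
            row = bytes((d + p) & 0xFF for d, p in zip(data[i:i + row_size], prev))
            rows.append(row)
            prev = row
        return rows

The filtered output, not the raw rows, is what gets handed to the general purpose compressor.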
It wasn't clear whether one could do a lot better than 0.87 bytes per state, or whether this was a practical limit. In image data there's a reasonable expectation of similarity between adjacent bytes of the same row; for the state data that's not true. But it turned out that slightly more sophisticated filters could still improve on that number. The one I ended up using worked like this:
Let's assume we have adjacent rows R1 = [1, 2, 3, 4] and R2 = [1, 2, 6, 4]. When outputting R2, we compare each byte to the same byte on the previous row, with a 0 for match and 1 for mismatch: diff = [0, 0, 1, 0]. We then emit that bitmap encoded as a VarInt, followed by just the bytes that did not match the previous row. In this example, the two bytes '0b00000100 6'. This filter alone compressed the benchmark to 2.2 bytes / state. But combining this filter + zstd got it down to 0.42 bytes / state. Or to put it another way, that's 3.36 bits per state, which is just a little bit over what the back of the envelope calculation suggested was needed to fit in RAM.
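Here's a small Python sketch of that filter as described; the exact bit ordering, the varint encoding, and diffing the first row against an all-zero row are my assumptions:

    def encode_varint(n):
        # Little-endian base-128 varint.
        out = bytearray()
        while True:
            byte = n & 0x7F
            n >>= 7
            out.append(byte | (0x80 if n else 0))
            if not n:
                return bytes(out)

    def filter_rows(rows):
        # For each row: a varint bitmap of which bytes differ from the
        # previous row (bit i set = byte i differs), followed by just the
        # differing bytes. The first row is diffed against an all-zero row.
        prev = bytes(len(rows[0]))
        out = bytearray()
        for row in rows:
            mask = 0
            changed = bytearray()
            for i, (b, p) in enumerate(zip(row, prev)):
                if b != p:
                    mask |= 1 << i
                    changed.append(b)
            out += encode_varint(mask)
            out += changed
            prev = row
        return bytes(out)

On the example above, the second row encodes to the varint 0b00000100 followed by the single byte 6. The output of filter_rows is what then gets compressed with zstd; decompression reverses the process one row at a time, using the previously reconstructed row.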
In practice the compression ratios improve as the sorted sets get more dense. Once the search gets to the point where memory starts becoming an issue, the compression ratios can get a lot better than that. The largest problem turned out to have 4.6G distinct visited states in the end. These states took 405MB when sorted and compressed with the above scheme. That's 0.7 bits per state. The compression and decompression end up taking about 25% of the program's CPU time, but that seems like a great tradeoff for cutting memory use to 1/100th.
The filter above does feel a bit wasteful due to the VarInt header on every row. It seems like it should be easy to improve on with very little extra cost in CPU or complexity. I tried a bunch of other variants that transposed the data into a column-major order, wrote the bitmasks in bigger blocks, and so on. These variants invariably got a much better compression ratio by themselves, but then didn't do as well once the output of the filter was compressed with zstd. It wasn't just a quirk of zstd either; the results were similar with gzip and bzip2. I don't have any great theories on why this particular encoding ended up compressing much better than the alternatives.
Another mystery is that the compression ratio ended up far better when the data was sorted little-endian rather than big-endian. I initially thought it was because the little-endian sort ends up with more leading zeros in the VarInt-encoded bitmask. But the difference persisted even for filters that didn't have such dependencies.
(There's a lot of research on compressing sorted sets of integers, since they're a basic building block of search engines. I didn't find a lot on compressing sorted fixed-size records though, and didn't want to start jumping through the hoops of representing my data as arbitrary precision integers.)
Oh no, I've cheated!
You might have noticed that the above pseudocode implementations of BFS were only returning a boolean for solution found / not found. That's not very useful. For most purposes you need to be able to produce a list of the exact steps of the solution, not just state that a solution exists.
On the surface the solution is easy. Rather than collect sets of states, collect mappings from states to a parent state. Then after finding a solution, just trace back the list of parent states from the end to the start. For the hash table based solution, it'd be something like:
    def bfs(graph, start, end):
        visited = {start: None}
        todo = [start]
        while todo:
            node = todo.pop_first()
            if node == end:
                return trace_solution(node, visited)
            for kid in adjacent(node):
                if kid not in visited:
                    visited[kid] = node
                    todo.push_back(kid)
        return None

    def trace_solution(state, visited):
        if state is None:
            return []
        return trace_solution(visited[state], visited) + [state]
Unfortunately this causes a couple of problems. First, it will totally kill the compression gains from the last section: the core assumption was that adjacent rows would be very similar. That's true when we just look at the states themselves, but there's no reason to believe it for the parent states; they're effectively random data. Second, the sort + merge solution has to read and write back all seen states on each iteration. To maintain the state / parent state mapping, we'd also have to read and write all this badly compressing data to disk on each iteration.
Sort + merge with multiple outputs
The program only needs the state/parent mappings at the very end, when tracing back the solution. We can thus maintain two data structures in parallel. 'Visited' is still the set of visited states, and gets recomputed during the merge just like before. 'Parents' is a mostly sorted list of state/parent pairs, which doesn't get rewritten. Instead the new states + their parents get appended to 'parents' after each merge operation.
    def bfs(graph, start, end):
        parents = Stream()
        visited = Stream()
        todo = Stream()
        parents.add((start, None))
        visited.add(start)
        todo.add(start)
        while True:
            new = []
            for node in todo:
                if node == end:
                    return trace_solution(node, parents)
                for kid in adjacent(node):
                    new.push_back((kid, node))
            new_stream = Stream()
            # sorted() orders by key; uniq() keeps one (key, parent) pair per key.
            for pair in new.sorted().uniq():
                new_stream.add(pair)
            todo, visited = merge_sorted_streams(new_stream, visited, parents)
        return None

    # Merges sorted streams new and visited. New contains pairs of
    # key + value (just the keys are compared), visited contains just
    # keys.
    #
    # Returns a sorted stream of keys that were only present in new, and
    # another sorted stream containing the keys that were present in
    # either or both of new and visited. Also appends the key + value
    # pairs for keys that were only present in new to the parents stream.
    def merge_sorted_streams(new, visited, parents):
        out_todo, out_visited = Stream(), Stream()
        while visited or new:
            if visited and new:
                visited_head = visited.peek()
                new_head = new.peek()[0]
                if visited_head == new_head:
                    out_visited.add(visited.pop())
                    new.pop()
                elif visited_head < new_head:
                    out_visited.add(visited.pop())
                elif visited_head > new_head:
                    out_todo.add(new_head)
                    out_visited.add(new_head)
                    parents.add(new.pop())
            elif visited:
                out_visited.add(visited.pop())
            elif new:
                out_todo.add(new.peek()[0])
                out_visited.add(new.peek()[0])
                parents.add(new.pop())
        return out_todo, out_visited
This gives us the best of both worlds from a runtime and working set perspective, but does mean using more secondary storage. A separate copy of the visited states grouped by depth turns out to also be useful later on for other reasons.
Swapping
Another detail ignored in the snippets of pseudocode is that there is no explicit code for disk IO, just an abstract interface Stream. The Stream might be a file stream or an in-memory array, but we've been ignoring that implementation detail. Instead the pseudocode is concerned with having a memory access pattern that would be disk friendly. In a perfect world that'd be enough, and the virtual memory subsystem of the OS would take care of the rest.
At least with Linux that doesn't seem to be the case. At one point (before the working set had been shrunk to fit in memory) I'd gotten the program to run in about 11 hours when the data was stored mostly on disk. I then switched the program to use anonymous pages instead of file-backed ones, and set up sufficient swap on the same disk. After three days the program had gotten a quarter of the way through, and was still getting slower over time. My optimistic estimate was that it'd finish in 20 days.
Just to be clear, this was exactly the same code and exactly the same access pattern. The only thing that changed was whether the memory was backed by an explicit on-disk file or by swap. It's pretty much axiomatic that swapping tends to totally destroy performance on Linux, whereas normal file IO doesn't. I'd always assumed it was due to programs having the gall to treat RAM as something to be randomly accessed. But that wasn't the case here.
Turns out that file-backed and anonymous pages are not treated identically by the VM subsystem after all. They're kept in separate LRU caches with different expiration policies, and they also appear to have different readahead / prefetching properties.
So now I know: Linux swapping will probably not work well even under optimal circumstances. If parts of the address space are likely to be paged out for a while, it's better to arrange manually for them to be file-backed than to trust swap. I did it by implementing a custom vector class that starts off as a purely in-memory implementation, and switches to an mmap on an unlinked temporary file once a size threshold is exceeded.
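To give an idea of the shape of that class, here's a rough Python sketch of the same trick; the real thing was a C++ vector-like class, and the class name and threshold here are made up:

    import mmap
    import tempfile

    class SpillingBuffer:
        """Append-only byte buffer: in RAM below a size threshold, then
        backed by an mmap of an unlinked temporary file."""

        def __init__(self, threshold=1 << 30):
            self.threshold = threshold
            self.data = bytearray()   # in-memory phase
            self.file = None          # file-backed phase
            self.map = None
            self.length = 0

        def append(self, chunk):
            if self.map is not None:
                old = self.length
                self.map.resize(old + len(chunk))   # grows the file and the mapping
                self.map[old:old + len(chunk)] = chunk
            else:
                self.data.extend(chunk)
                if len(self.data) > self.threshold:
                    self._spill()
            self.length += len(chunk)

        def _spill(self):
            # TemporaryFile is already unlinked, so the space is reclaimed
            # automatically once the buffer (or the process) goes away.
            self.file = tempfile.TemporaryFile()
            self.file.write(self.data)
            self.file.flush()
            self.map = mmap.mmap(self.file.fileno(), len(self.data))
            self.data = bytearray()

        def view(self):
            # Read access works the same way in both phases.
            return self.map if self.map is not None else memoryview(self.data)

The point isn't the exact API, it's that once the data is file-backed the kernel treats it like ordinary page cache rather than anonymous memory headed for swap.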
Compressing new states before merging
In the simplified performance model the assumption was that there would be 100M new states per depth. That turned out not to be too far off reality (the most difficult puzzle peaked at about 150M unique new states from one depth layer). But it's also not the right thing to measure; the working set before the merge isn't determined by just the unique states, but by all the states that were output for that iteration. That measure peaks at 880M output states / depth. These 880M states a) need to be accessed with a random access pattern for the sorting, b) can't be compressed efficiently since they're not yet sorted, and c) need to be stored along with their parent states. That's a roughly 16GB working set.
The obvious solution would be to use some form of external sorting. Just write all the states to disk, do an external sort, do a deduplication, and then execute the merge just as before. This is the solution I went with first, but while it mostly solved problem A, it did nothing for B and C.
The alternative I ended up with was to collect the states into an in-memory array. If the array grows too large (e.g. more than 100M elements), it's sorted, deduplicated and compressed. This gives us a bunch of sorted runs of states, with no duplicates inside the run but potentially some between the runs. The code for merging the new and visited states is fundamentally the same; it's still based on walking through the streams in lockstep. The only change is that instead of walking through just the two streams, there's a separate stream for each of the sorted runs of new states.
The compression ratios for these 100M state runs are of course not quite as good as for compressing the set of all visited states. But even so, it cuts down both the working set and the disk IO requirements by a ton. There's a little bit of extra CPU from having to maintain a priority queue of streams, but it was still a great tradeoff.
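In Python terms, the per-run merge can be sketched with a single small priority queue over the run heads, for example like this (the real code does this in C++ and folds the visited stream into the same pass):

    import heapq

    def merge_runs(runs):
        # Merge several sorted runs of states into one sorted, duplicate-free
        # stream. heapq.merge keeps a priority queue over the run heads, so
        # each run can stay compressed and be decoded lazily as an iterator.
        previous = None
        for state in heapq.merge(*runs):
            if state != previous:
                yield state
                previous = state

The result then plays the role of new_stream in the merge_sorted_streams pseudocode above.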
Saving space on the parent states
At this point the vast majority of the space used by this program is spent on storing the parent states, so that we can reconstruct the solution after finding it. They are unlikely to compress well, but is there maybe a CPU/memory tradeoff to be made?
What we need is a mapping from a state S' at depth D+1 to its parent state S at depth D. If we could iterate through all possible parent states of S', we could simply check if any of them appear at depth D in our visited set. (We've already produced the visited set grouped by depth as a convenient byproduct when outputting the state/parent mappings from the merge.) Unfortunately that doesn't work for this problem; it's simply too hard to generate all the possible parent states S given S'. It'd probably work just fine for many other search problems though.
If we can only generate the state transitions forward, not backward, how about just doing that then? Let's iterate through all the states at depth D, and see what output states they have. If some state produces S' as an output, we've found a workable S. The issue with the plan is that it increases the total CPU usage of the program by 50%. (Not 100%, since on average we find S after looking at half the states of depth D).
So I don't like either of the extremes, but at least there is a CPU/memory tradeoff available there. Is there maybe a more palatable option somewhere in the middle? What I ended up doing was to not store the pair (S', S), but instead (S', H(S)), where H is an 8 bit hash function. To find an S given S', again iterate through all the states at depth D. But before doing anything else, compute the same hash. If it doesn't match the stored H(S), this isn't the state we're looking for, and we can just skip it. This optimization means doing the expensive re-computation for just 1/256 of the states, which is a negligible CPU increase, while cutting the memory spent on storing the parent states from 8-10 bytes to 1 byte per state.
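A sketch of that lookup, with hash8 and adjacent as hypothetical stand-ins for the real 8-bit hash and the move generator:

    def find_parent(child, depth_d_states, parent_hash):
        # Scan the depth-D states; only run the expensive move generation
        # for candidates whose 8-bit hash matches the one stored with 'child'.
        for candidate in depth_d_states:
            if hash8(candidate) != parent_hash:
                continue              # cheap rejection, ~255/256 of states
            if child in adjacent(candidate):
                return candidate
        return None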
What didn't or might not work
The previous sections go through a sequence of high level optimizations that worked. There were other things that I tried that didn't work, or that I found in the literature but decided would not actually work in this particular case. Here's a non-exhaustive list.
At one point I was not recomputing the full visited set at every iteration. Instead it was kept as multiple sorted runs, and those runs were occasionally compacted. The benefit was fewer disk writes and less CPU spent on compression. The downside was more code complexity and a worse compression ratio. I originally thought this design made sense since in my setup writes were more expensive than reads, but the compression ratio turned out to be worse by a factor of 2. The tradeoffs are non-obvious, and in the end I reverted back to the simpler form.
There is a little bit of research into executing huge breadth-first searches for implicit graphs on secondary storage; a 2008 survey paper is a good starting point. As one might guess, the idea of doing the deduplication in a batch with sort + merge on secondary storage isn't novel. The surprising part is that it was apparently only discovered in 1993. That's pretty late! There are also some later proposals for secondary storage breadth-first search that don't require a sorting step.
One of them was to map the states to integers, and to maintain an in-memory bitmap of the visited states. This is totally useless for my case, since the sizes of the encodable vs. actually reachable state spaces are so different. And I'm a bit doubtful about there being any interesting problems where this approach works.
The other viable-sounding alternative is based on temporary hash tables. The visited states are stored unsorted in a file, and the outputs from depth D are stored in a hash table. Then iterate through the visited states file and look each state up in the hash table; if it's found, remove it from the table. After iterating through the whole file, only the non-duplicates remain in the table. They can then be appended to the file, and used to initialize the todo list for the next iteration. If the number of outputs is so large that the hash table doesn't fit in memory, both the file and the hash table can be partitioned using the same criteria (e.g. top bits of the state), with each partition getting processed independently.
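In pseudocode, one iteration of that scheme might look something like this (my reading of the description, not code from my solver; read_states and append_states are assumed helpers over the unsorted file):

    def dedupe_with_hash_table(visited_file, outputs):
        # Drop every output that's already in the visited file, append the
        # rest, and return them as the todo list for the next depth.
        candidates = set(outputs)
        for state in read_states(visited_file):
            candidates.discard(state)
        append_states(visited_file, candidates)
        return candidates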
While there are benchmarks claiming the hash-based approach is roughly 30% faster than sort+merge, the benchmarks don't really seem to consider compression. I just don't see how giving up the compression gains could be worth it, so didn't experiment with these approaches at all.
The other relevant branch of research that seemed promising was database query optimization. The deduplication problem seems very much related to database joins, with exactly the same sort vs. hash dilemma. Obviously some of these findings should carry over to a search problem. The difference might be that the output of a database join is transient, while the outputs of a BFS deduplication persist for the rest of the computation. It feels like that changes the tradeoffs: it's not just about how to process one iteration most efficiently, it's also about having the outputs in the optimal format for the next iteration.
Conclusion
That concludes the things I learned from this project that seem generally applicable to other brute force search problems. These tricks combined to take the hardest puzzles of the game from an effective memory footprint of 50-100GB down to 500MB, while degrading gracefully when a problem exceeds the available memory and spills to disk. The solver is also 50% faster than a naive hash table based state deduplication even for puzzles that fit into memory.
The next post will deal with optimizing grid-based spatial puzzle games in general, as well as some issues specific just to this particular game.
In the meanwhile, Snakebird is available at least on Steam, Google Play, and the App Store. I recommend it for anyone interested in a very hard but fair puzzle game.