Weighty matters

Posted on May 15, 2011, under coding, general.

One of my favourite programming exercises is to write a text generator using Markov chains. It usually takes a relatively small amount of code and is especially useful for learning new programming languages. A useful goal is to have a tool that ingests a large body of text, breaks it into N-tuples (where N is chosen by the operator), and then emits a random text of length L that uses those tuples at the same frequencies as the original body did.
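As a rough illustration of that exercise, here is a minimal word-level sketch. `build_chain` and `generate` are hypothetical names. It assumes whitespace-separated words and N ≥ 2, and it treats the source text as circular so that generation never reaches a dead end. Keeping duplicate followers in each list means `random.choice` picks them at their observed frequencies, which is the flattened form of weighted selection discussed below.

```python
import random
from collections import defaultdict


def build_chain(words, n):
    """Map every (n-1)-word prefix to the list of words seen right after it.

    Duplicate followers are kept on purpose: a follower seen twice is twice as
    likely to be picked. The text is treated as circular (an assumption of
    this sketch) so every prefix has at least one follower.
    """
    if n < 2:
        raise ValueError("n must be at least 2")
    chain = defaultdict(list)
    size = len(words)
    for i in range(size):
        prefix = tuple(words[(i + j) % size] for j in range(n - 1))
        chain[prefix].append(words[(i + n - 1) % size])
    return chain


def generate(chain, length, rng=random):
    """Emit `length` words, each chosen from the followers of the last prefix."""
    prefix = rng.choice(sorted(chain))  # sorted so a seeded rng is repeatable
    output = list(prefix)
    while len(output) < length:
        followers = chain[tuple(output[-len(prefix):])]
        output.append(rng.choice(followers))
    return output[:length]


text = "the cat sat on the mat and the cat ran off the mat"
chain = build_chain(text.split(), 3)
print(" ".join(generate(chain, 12, random.Random(42))))
```

Every consecutive N-tuple in the output appears somewhere in the (circular) source, and common tuples turn up more often simply because their followers are listed more times.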

On its own, that task can produce some fun results, but it’s also a very repurposable technique. You can use this kind of statistical generation to simulate realistic network jitter (record some N-tuples of observed RTTs with ping), or to build awesome simulated-user fuzz tests (record some N-tuples of observed user inputs). It’s surprising that it isn’t more common.

But in my experience of working with newcomers on these problems, a common first stumbling block is how to do weighted selection at all. Put most simply, if we have a table of elements;

element   weight
A         2
B         1
C         1

how do we write a function that will choose A about half the time, and B and C about a quarter each? It also turns out that this is a really interesting design problem. We can choose to implement a random solution, a non-random solution, a solution that runs in constant time, a solution that runs in linear time and a solution that runs in logarithmic time. This post is about those potential solutions.

Non-random solutions

When coming to the problem, the first thing to decide is whether we really want the selection to be random or not. One could imagine a function that tries to keep track of previous selections, for example;

elements = [ { 'name' : 'A' , 'weight' : 2 },
             { 'name' : 'B' , 'weight' : 1 },
             { 'name' : 'C' , 'weight' : 1 } ]

def select_element(elements, count):
    # count is our position in the cycle; it must wrap at the sum of the
    # weights (not the number of elements), or C would never be reached
    total_weight = sum(element['weight'] for element in elements)
    running_weight = 0
    for element in elements:
        running_weight += element['weight']
        if count < running_weight:
            return ( element ,
                     (count + 1) % total_weight )

# select some elements
c = 0
(i , c) = select_element(elements, c)
(j , c) = select_element(elements, c)
(k , c) = select_element(elements, c)

This function uses the “count” parameter to remember where it left off last time, and runs in time proportional to the number of elements. Starting from a count of 0, the first two calls return A, followed by B, followed by C, then A twice again, and so on. Hopefully that’s obvious.

There’s a very simple optimisation that makes it run in constant time: just sacrifice memory proportional to the sum of the weights, by repeating each element once per unit of weight;

elements = [ { 'name' : 'A' },
             { 'name' : 'A' },
             { 'name' : 'B' },
             { 'name' : 'C' } ]

def select_element(elements, count):
  # each element appears once per unit of weight, so a plain index does the work
  return  ( elements[count] ,
            (count + 1) % len(elements) )

This basic approach is surprisingly common; it's how a lot of networking devices implement weighted path selection.
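In Python the same flattened round-robin can live inside an iterator, so callers don't have to thread `count` through every call. This is a sketch: the `(name, weight)` pair input and the `weighted_round_robin` name are assumptions of this example, and it builds the flattened list from the weights instead of writing it out by hand.

```python
import itertools


def weighted_round_robin(weighted):
    """Cycle forever through names, repeating each one `weight` times per cycle.

    `weighted` is a list of (name, weight) pairs; the flattened list costs
    memory proportional to the sum of the weights, like the version above.
    """
    flattened = [name for name, weight in weighted for _ in range(weight)]
    return itertools.cycle(flattened)


picker = weighted_round_robin([('A', 2), ('B', 1), ('C', 1)])
print([next(picker) for _ in range(8)])  # ['A', 'A', 'B', 'C', 'A', 'A', 'B', 'C']
```

Each `next()` is constant time, and the iterator itself carries the position that `count` used to track.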

Random solutions

But if we're feeding into some kind of simulation, or a fuzz test, then randomised selection is probably a better thing. Luckily, there are at least 3 ways to do it. The first is to reuse the space-inefficient layout of our "optimised" non-random selection: flatten the list of elements into an array and just jump into it at random;

import random

elements = [ { 'name' : 'A' },
             { 'name' : 'A' },
             { 'name' : 'B' },
             { 'name' : 'C' } ]

def select_element(elements):
  return random.choice(elements)

As before, this runs in constant time, but it's easy to see how this could get really ugly if the weights look more like;

element   weight
A         998
B         1
C         1

Unfortunately, real datasets often follow distributions just like this (Zipfian ones, for example). Storing every word in the English language, repeated in proportion to its frequency, would take an unreasonable amount of space with this method.

Reservoir sampling

But luckily there's a way to avoid all of that space, and it has some other uses too. It's called reservoir sampling, and it's pretty magical. Reservoir sampling is a form of statistical sampling that lets you choose a limited number of samples from an arbitrarily large stream of data, without knowing in advance how big the stream is. It doesn't seem related, but it is.

Imagine you have a stream of data, and it looks something like;

"Banana" , "Car" , "Bus", "Apple", "Orange" , "Banana" , ....

and you want to choose a sample of 3 events. All events in the stream should have equal probability of making it into the sample. A better real-world example is collecting a sample of 1,000 web server requests over the last minute, but you get the idea.

What a reservoir sample does is simple;

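import random

def reservoir_sample(events, size):
    events_observed = 0
    sample = [ None ] * size

    for event in events:
        if events_observed >= len(sample):
            # pick a slot from 0..events_observed inclusive; only slots below
            # len(sample) exist, so this event is kept with probability
            # size / (events_observed + 1)
            r = random.randint(0, events_observed)
        else:
            # the first `size` events fill the reservoir in order
            r = events_observed

        if r < len(sample):
            sample[r] = event
        events_observed += 1

    return sample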

and how this works is pretty cool. The first 3 events are always sampled, since r is simply 0, 1 and 2 for them. The interesting part is what happens after that. The fourth element gets an r drawn from 0, 1, 2, 3, and it lands in the sample whenever r is 0, 1 or 2, so it has a 3/4 probability of being selected. So that's pretty simple.

But consider the likelihood of an element already in the sample "staying". For any given element, there are two ways of staying. One is that the fourth element is not chosen (1/4 probability); the other is that the fourth element is chosen (3/4 probability) but a different element is picked for replacement (2/3 probability). If you do the math;

1/4 + (3/4 * 2/3)   =>
1/4 + 6/12          =>
3/12 + 6/12         =>
9/12                =>
3/4

We see that any given element has a 3/4 likelihood of staying, which is exactly what we want: four elements have been observed, and each of them has a 3/4 probability of being in our sample. Let's extend this by one iteration and see what happens at element 5.
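A quick simulation makes these probabilities concrete. This sketch restates the reservoir sample with an injectable random source (the `rng` parameter and the `inclusion_frequencies` helper are additions for illustration) and counts how often each event ends up in the sample. The results should hover around 3/4 with four events, and around 3/5 once a fifth event arrives, as the next step shows.

```python
import random
from collections import Counter


def reservoir_sample(events, size, rng=random):
    """The reservoir sample from above, with the random source passed in."""
    sample = [None] * size
    for observed, event in enumerate(events):
        # the first `size` events fill the reservoir in order; after that an
        # event is kept with probability size / (observed + 1)
        r = rng.randint(0, observed) if observed >= size else observed
        if r < size:
            sample[r] = event
    return sample


def inclusion_frequencies(n, size, trials, seed=0):
    """Fraction of trials in which each of the events 0..n-1 was sampled."""
    rng = random.Random(seed)
    counts = Counter()
    for _ in range(trials):
        counts.update(reservoir_sample(range(n), size, rng))
    return [counts[event] / trials for event in range(n)]


print(inclusion_frequencies(n=4, size=3, trials=20000))
print(inclusion_frequencies(n=5, size=3, trials=20000))
```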

When we get to this element, it has a 3/5 chance of being selected (r is a random number from the set 0, 1, 2, 3, 4; if it is 0, 1 or 2 then the element is chosen). We've already established that the previous elements each had a 3/4 probability of being in the sample. Those that are in the sample again have two ways of staying: either the fifth element isn't chosen (2/5 likelihood), or it is but a different element is replaced (3/5 * 2/3). Again, let's do the math;

3/4 * (2/5 + (3/5 * 2/3))  =>
3/4 * (2/5 + 6/15)         =>
3/4 * (2/5 + 2/5)          =>
3/4 * 4/5                  =>
3/5

So, once again, all elements have a 3/5 likelihood of being in the sample. Every time another element is observed the arithmetic gets a bit longer, but the result is the same: every element always has an equal probability of being in the final sample.
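The same two-case argument works at every step, for any sample size k (k = 3 above). Suppose that after n ≥ k events each one has probability k/n of being in the sample, which holds trivially at n = k, where that probability is 1. When event n+1 arrives:

```latex
\begin{aligned}
P(\text{event } n+1 \text{ enters the sample}) &= \frac{k}{n+1} \\
P(\text{a sampled event stays}) &= \left(1 - \frac{k}{n+1}\right) + \frac{k}{n+1}\cdot\frac{k-1}{k}
  = \frac{n+1-k+k-1}{n+1} = \frac{n}{n+1} \\
P(\text{an earlier event is in the sample}) &= \frac{k}{n}\cdot\frac{n}{n+1} = \frac{k}{n+1}
\end{aligned}
```

With k = 3 and n = 4 this is exactly the 3/4 * 4/5 = 3/5 calculation above, and by induction every event in a stream of n ends up in the sample with probability k/n.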

So what does this have to do with random selection? Well, imagine our original weighted elements as a stream of events, and that our goal is to choose a sample of 1.

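import random

elements = [ { 'name' : 'A' },
             { 'name' : 'A' },
             { 'name' : 'B' },
             { 'name' : 'C' } ]

def select_element(elements):
    elements_observed = 0
    chosen_element = None
    for element in elements:
        # a reservoir of size 1: with n elements seen so far, this one
        # replaces the current choice with probability 1 / (n + 1)
        r = random.randint(0, elements_observed)
        if r < 1:
            chosen_element = element
        elements_observed += 1
    return chosen_element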

So far, this is actually worse than what we had before. We're still using memory proportional to the sum of the weights, and now we're running in linear time too. But now that we've structured things like this, we can use the weights to cheat;

import random

elements = [ { 'name' : 'A' , 'weight' : 2 },
             { 'name' : 'B' , 'weight' : 1 },
             { 'name' : 'C' , 'weight' : 1 } ]

def select_element(elements):
    total_weight = 0
    chosen_element = None
    for element in elements:
        total_weight += element['weight']
        # keep this element with probability weight / (running total so far)
        r = random.randint(0, total_weight - 1)
        if r < element['weight']:
            chosen_element = element

    return chosen_element

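It's the same kind of self-correcting math as before, except that now we can take bigger jumps rather than steps of just 1 every time: each element is kept with probability equal to its weight divided by the running total of weights seen so far.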
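Writing the weights as w_1, ..., w_m and W_j for the running total after j elements (the code's `total_weight`), a telescoping product shows why the final choice is proportional to weight:

```latex
\begin{aligned}
P(\text{element } i \text{ is chosen at step } i) &= \frac{w_i}{W_i} \\
P(\text{element } j > i \text{ does not replace it}) &= 1 - \frac{w_j}{W_j} = \frac{W_{j-1}}{W_j} \\
P(\text{element } i \text{ is returned}) &= \frac{w_i}{W_i}\prod_{j=i+1}^{m}\frac{W_{j-1}}{W_j}
  = \frac{w_i}{W_i}\cdot\frac{W_i}{W_m} = \frac{w_i}{W_m}
\end{aligned}
```

For the A, B, C table this gives P(A) = 2/2 * 2/3 * 3/4 = 1/2 and P(B) = P(C) = 1/3 * 3/4 = 1/4, as intended. Setting every weight to 1 recovers the unweighted reservoir of size 1.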

We now have a loop that runs in time proportional to the number of elements, and uses memory proportional to the number of elements rather than the sum of their weights. That's pretty good, but there's another optimisation we can make.

Using a tree

If there are a large number of elements, it can still be a pain to iterate over them all just to select one. One solution to this problem is to compile the weighted elements into a weighted binary tree, so that each selection needs only O(log n) operations.

Let's take a larger weighted set;

element   weight
A         1
B         2
C         3
D         1
E         1
F         1
G         1
H         1
I         1
J         1

which we can express in tree form, where each node carries the cumulative weight of its children.

Tree-solutions like this lend themselves easily to recursion;

import random

elements = [ {  'name' : 'A',
                'weight' : 10 },
             {  'name' : 'B',
                'weight' : 20 },
             {  'name' : 'C',
                'weight' : 30 },
             {  'name' : 'D',
                'weight' : 10 },
             {  'name' : 'E',
                'weight' : 10 },
             {  'name' : 'F',
                'weight' : 10 },
             {  'name' : 'G',
                'weight' : 10 },
             {  'name' : 'H',
                'weight' : 10 },
             {  'name' : 'I',
                'weight' : 10 },
             {  'name' : 'J',
                'weight' : 10 } ]

# Compile a set of weighted elements into a weighted tree
def compile_choices(elements):
    if len(elements) > 1:
        # integer division (//) keeps the slice index an int on Python 3
        left  = elements[ : len(elements) // 2 ]
        right = elements[ len(elements) // 2 : ]

        return [ { 'child': compile_choices(left) ,
                   'weight' : sum( map( lambda x : x['weight'] , left ) ) } ,
                 { 'child': compile_choices(right) ,
                   'weight' : sum( map( lambda x : x['weight'] , right ) ) } ]
    else:
        return elements

# Choose an element from a weighted tree
def tree_weighted_choice(tree):
    if len(tree) > 1:
        total_weight = tree[0]['weight'] + tree[1]['weight']
        if random.randint(0, total_weight - 1) < tree[0]['weight']:
            return tree_weighted_choice( tree[0]['child'] )
        else:
            return tree_weighted_choice( tree[1]['child'] )
    else:
        return tree[0]['name']

tree = compile_choices(elements)

And there we have it! A pretty good trade-off: we get logarithmic selection time for a relatively small overhead in memory.
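A different technique, not the tree from this post, gets the same O(log n) selection with a flat array: store running totals of the weights and binary-search them with the standard `bisect` module. `compile_cumulative` and `cumulative_weighted_choice` are hypothetical names for this sketch.

```python
import bisect
import itertools
import random


def compile_cumulative(elements):
    """Running totals of the weights, e.g. weights 10, 20, 30 -> [10, 30, 60]."""
    return list(itertools.accumulate(element['weight'] for element in elements))


def cumulative_weighted_choice(elements, cumulative, rng=random):
    """Pick a point in [0, total) and binary-search for the bucket it falls in."""
    r = rng.randrange(cumulative[-1])
    return elements[bisect.bisect_right(cumulative, r)]['name']


elements = [{'name': 'A', 'weight': 10},
            {'name': 'B', 'weight': 20},
            {'name': 'C', 'weight': 30}]
cumulative = compile_cumulative(elements)
print(cumulative_weighted_choice(elements, cumulative))
```

For everyday use, Python 3.6 and later also ship `random.choices(population, weights=...)`, which CPython implements with much the same running-totals-and-bisect approach. The trade-off against a tree is that changing one weight means recomputing every running total after it.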

A Huffman(ish) tree

I should have included this in the first version of this post, and Fergal was quick to point it out in the comments, but there is a further optimisation we can make. Instead of using a tree that is balanced purely by the number of elements, as above, we can use a tree that is balanced by the sum of weights.

To reuse his example, imagine you have elements weighted [ 10000, 1, 1, 1 ... ] (with a hundred ones). It doesn't make sense to bury the very weighty element deep in the tree; instead it should go near the root, so that we minimise the expected number of lookups.

def huffman_choices(elements):
    if len(elements) > 1:
        total_weight = sum( map( lambda x : x['weight'] , elements ) )
        sorted_elements = sorted(elements, key = lambda x : x['weight'] )[::-1]

        left = []
        right = []
        observed_weight = 0
        for element in sorted_elements:
            if observed_weight < total_weight / 2:
                left.append(element)
            else:
                right.append(element)

            observed_weight += element['weight']

        return [ { 'child': huffman_choices(left) ,
                   'weight' : sum( map( lambda x : x['weight'] , left ) ) } ,
                 { 'child': huffman_choices(right) ,
                   'weight' : sum( map( lambda x : x['weight'] , right ) ) } ]
    else:
        return elements
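To put a number on Fergal's example, this sketch mirrors the splitting rules of `compile_choices` (split by element count) and `huffman_choices` (greedy split by weight), but only records each leaf's depth: the number of random branch decisions `tree_weighted_choice` makes to reach it. The helper names are hypothetical, and positive weights are assumed.

```python
def count_split_depths(weights, depth=0):
    """(weight, depth) per leaf when splitting by element count, as compile_choices does."""
    if len(weights) <= 1:
        return [(w, depth) for w in weights]
    middle = len(weights) // 2
    return (count_split_depths(weights[:middle], depth + 1)
            + count_split_depths(weights[middle:], depth + 1))


def weight_split_depths(weights, depth=0):
    """(weight, depth) per leaf under the greedy weight split of huffman_choices."""
    if len(weights) <= 1:
        return [(w, depth) for w in weights]
    total = sum(weights)
    left, right, observed = [], [], 0
    for w in sorted(weights, reverse=True):
        # heaviest first: fill the left side until it holds half the weight
        (left if observed < total / 2 else right).append(w)
        observed += w
    return (weight_split_depths(left, depth + 1)
            + weight_split_depths(right, depth + 1))


def expected_lookups(leaves):
    """Average branch decisions per selection, weighting each leaf by its weight."""
    total = sum(w for w, _ in leaves)
    return sum(w * d for w, d in leaves) / total


weights = [10000] + [1] * 100  # one heavy element followed by a hundred ones
print(expected_lookups(count_split_depths(weights)))
print(expected_lookups(weight_split_depths(weights)))
```

With the count-balanced split the heavy element sits six levels down, so almost every selection costs about six decisions. The weight-balanced split puts it directly under the root, bringing the average close to one.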

None of these techniques is novel; in fact they're all quite common and very old, yet for some reason they aren't widely known. Reservoir sampling, which on its own is incredibly useful, is only afforded a few sentences in Knuth.

But dealing with randomness and sampling is one of the inevitable complexities of programming. If you've never done it, it's worth taking the above and trying to write your own weighted Markov chain generator, and then having even more fun thinking about how to test it.

Prime and Proper

Posted on September 2, 2010, under general.

A common software engineering pattern is the need to test elements for set membership.

Practical questions like “is this item in my cache?” arise surprisingly often. Bloom filters are a great mechanism for performing this test, and if you haven’t encountered them before, they’re well worth knowing about. But over the last year, I’ve been experimenting with a different method for these kinds of membership tests on enumerated sets, and thought it was worth sharing.

It’s definitely not the first time the technique has been used; Cian found some papers on its use. But it’s not a common pattern, which is a pity, because it’s incredibly simple and lends itself very well to modern hardware.

First things first, some huge disclaimers;

  • The approach only works on enumerated sets. Each element in the problem space has to be assigned a unique counted number. Hashes won’t do.
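  • The method is constrained by the size of the numbers involved: the largest set product is bounded by the largest assigned prime raised to the maximum size of any set, so it grows with both the total number of elements and the maximum set size.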

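So, it’s very simple: take each element in the problem space, assign it a unique prime number, and then represent each set as the product of its elements’ primes. This works because of the fundamental theorem of arithmetic: every integer greater than 1 has exactly one factorisation into primes, so a product can only ever decode back to the set it came from.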

Here’s a summary of operations;

Computing a set:
Compute the product of all of the elemental primes in the set.

s = a * b * c

Testing for element membership:
Test the modulus of the set number.

(s % a) == 0

Adding to the set:
Multiply the existing set product by the new element’s prime number.

s *= d

Removing from the set:
Divide the existing set product by the element’s prime number, using exact integer division.

s //= a
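The four operations above can be wrapped up in a small class. This is a sketch, not the original code: `PrimeSet` and `first_primes` are hypothetical names, primes come from simple trial division, and the universe of elements must be known up front (the enumerated-set constraint).

```python
def first_primes(n):
    """The first n primes, by trial division (fine for small universes)."""
    primes, candidate = [], 2
    while len(primes) < n:
        if all(candidate % p for p in primes if p * p <= candidate):
            primes.append(candidate)
        candidate += 1
    return primes


class PrimeSet:
    """A set over a fixed universe, stored as the product of its members' primes."""

    def __init__(self, universe):
        # assign each element of the enumerated universe its own prime
        self.prime_of = dict(zip(universe, first_primes(len(universe))))
        self.value = 1  # the empty set is the empty product

    def add(self, element):
        p = self.prime_of[element]
        if self.value % p:  # multiply in at most once, so one division removes it
            self.value *= p

    def remove(self, element):
        p = self.prime_of[element]
        if self.value % p == 0:
            self.value //= p  # integer division keeps the value exact

    def __contains__(self, element):
        return self.value % self.prime_of[element] == 0


s = PrimeSet(['a', 'b', 'c', 'd'])  # primes 2, 3, 5, 7
for element in ['a', 'b', 'c']:
    s.add(element)
print(s.value, 'd' in s)  # 30 False
s.add('d')
s.remove('a')
print(s.value, 'a' in s)  # 105 False
```

Python's integers are arbitrary precision, so the same class keeps working as set products grow far beyond 64 bits.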

Now at this point, you’re probably skeptical about the usefulness of the technique, given the constraints. Other operations between sets are possible too, and unions and intersections don’t actually need factorisation: because every set number is a product of distinct primes, the intersection of s and t is gcd(s, t) and their union is lcm(s, t). Listing the members of a set is another matter, since that means trial-dividing by candidate primes. But look at the benefits;

  • Unlike Bloom filters, there are zero false positives, and zero false negatives. The method is 100% precise.
  • Thanks to their use in cryptography, libraries, CPUs and other hardware have emerged that are highly efficient at computing products and remainders of very large numbers.

As a case in point, suppose you want to model the internet as a system of interconnected autonomous systems (ASes) and paths. A really common test when using such models is to determine whether or not a particular system is on a particular path.

If we take the internet as 100,000 autonomous systems, and the longest practical path on the internet as containing 20 such systems (longer paths exist, but are oddities; in real terms these numbers are both over-estimates), the largest product we would ever need to compute is smaller than 2^410. The 100,000th prime is 1,299,709, which is about 2^20.3, so a product of 20 primes no larger than it stays below 2^407. That’s much smaller than even the primes used in modern cryptography, and much, much smaller than their products, so there is a lot of headroom.
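The bound is easy to check. This sketch sieves the first 100,000 primes and measures the worst case for a 20-system path: the largest prime raised to the 20th power, which overestimates any real product of 20 distinct primes. `first_n_primes` is a hypothetical helper, and it sizes its sieve with the standard bound that the n-th prime is below n(ln n + ln ln n) for n ≥ 6.

```python
import math


def first_n_primes(n):
    """Return the first n primes using a sieve of Eratosthenes."""
    # for n >= 6 the n-th prime is below n * (ln n + ln ln n); 15 covers smaller n
    limit = 15 if n < 6 else int(n * (math.log(n) + math.log(math.log(n)))) + 1
    is_prime = bytearray([1]) * (limit + 1)
    is_prime[0] = is_prime[1] = 0
    for i in range(2, math.isqrt(limit) + 1):
        if is_prime[i]:
            # cross out every multiple of i from i*i upwards
            is_prime[i * i::i] = bytes(len(range(i * i, limit + 1, i)))
    return [i for i, flag in enumerate(is_prime) if flag][:n]


primes = first_n_primes(100_000)
largest = primes[-1]
worst_case = largest ** 20  # upper bound on any product of 20 of these primes
print(largest, worst_case.bit_length())
```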

Actually, that was my first real-world use case for this technique: modeling the internet in a few hundred lines of Python, on a laptop, as a prototype. Surprisingly, a MacBook Air could perform well over one million such membership tests per second without skipping a beat, after reordering to assign the most popular AS numbers the lowest prime numbers. And incidentally, if that kind of problem interests you, CloudFront are hiring.
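Reordering by popularity is straightforward: count how often each AS appears across the observed paths and hand out primes in that order. The commonest ASes then get the smallest primes, which keeps typical path products short. This sketch uses hypothetical helper names and made-up paths drawn from the AS range reserved for documentation (64496 to 64511). It illustrates the idea rather than reproducing the original prototype.

```python
from collections import Counter


def first_primes(n):
    """The first n primes, by trial division (fine for a small example)."""
    primes, candidate = [], 2
    while len(primes) < n:
        if all(candidate % p for p in primes if p * p <= candidate):
            primes.append(candidate)
        candidate += 1
    return primes


def assign_primes_by_popularity(paths):
    """Give the most frequently seen AS numbers the smallest primes."""
    counts = Counter(asn for path in paths for asn in path)
    ranked = [asn for asn, _ in counts.most_common()]  # ties keep first-seen order
    return dict(zip(ranked, first_primes(len(ranked))))


def encode_path(path, prime_of):
    """A path's set number is the product of its ASes' primes."""
    product = 1
    for asn in path:
        product *= prime_of[asn]
    return product


paths = [[64496, 64500, 64510], [64496, 64501], [64496, 64500, 64502]]
prime_of = assign_primes_by_popularity(paths)
encoded = [encode_path(path, prime_of) for path in paths]
print(prime_of[64496], [e % prime_of[64500] == 0 for e in encoded])
```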

Now a question, for the information theorists: how is it that this method can seemingly perfectly encode N bits of information in fewer than N bits? For example, it takes 2 bits to represent the number “2” and 2 bits to represent the number “3”, yet representing both through their product, “6”, takes only 3 bits. But this isn’t like addition or some other operation; there really is only one way of factorising 6 to get back to “3” and “2”. The information is seemingly encoded perfectly.