Weighty matters

Posted on May 15, 2011, under coding, general.

One of my favourite programming exercises is to write a text generator using Markov Chains. It usually takes a relatively small amount of code and is especially useful for learning new programming languages. A useful goal is to be able to have a tool that ingests a large body of text, breaks it into N-tuples (where N is chosen by the operator), and then emits a random text of length L that uses those tuples at the same frequency as the original body did.
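As a rough illustration of that goal, here is a minimal Python 3 sketch. It assumes whitespace-separated words and N of at least 2, and the names `build_chain` and `generate` are hypothetical. It stores every observed successor, duplicates included, so picking uniformly from that list reproduces the original frequencies. This is the same flattened-list trick discussed later in the post.

```python
import random
from collections import defaultdict


def build_chain(words, n):
    """Map each (n-1)-word prefix to the list of words that followed it.

    Duplicates are kept on purpose: a successor seen twice as often
    is twice as likely to be picked by random.choice later.
    Assumes n >= 2 and len(words) >= n.
    """
    chain = defaultdict(list)
    for i in range(len(words) - n + 1):
        prefix = tuple(words[i:i + n - 1])
        chain[prefix].append(words[i + n - 1])
    return chain


def generate(chain, length):
    """Emit up to `length` words by repeatedly sampling a successor."""
    prefix_len = len(next(iter(chain)))
    output = list(random.choice(list(chain)))  # random starting prefix
    while len(output) < length:
        successors = chain.get(tuple(output[len(output) - prefix_len:]))
        if not successors:
            break  # dead end: this prefix only occurred at the end of the text
        output.append(random.choice(successors))
    return output[:length]
```

For example, `' '.join(generate(build_chain(text.split(), 3), 50))` would produce up to 50 words from a trigram model of `text`.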

On its own, that task can produce some fun results, but it’s also a very repurposable technique. You can use this kind of statistical generation to simulate realistic network jitter (record some N-tuples of observed RTT with ping), or to build awesome simulated-user fuzz tests (record some N-tuples of observed user inputs). It’s surprising that it isn’t more common.

But in my experience of working with newcomers, the first common stumbling block when approaching these problems is how to do weighted selection at all. Put most simply, if we have a table of elements;

element   weight
A         2
B         1
C         1

how do we write a function that will choose A about half the time, and B and C about a quarter each? It also turns out that this is a really interesting design problem. We can choose to implement a random solution, a non-random solution, a solution that runs in constant time, a solution that runs in linear time and a solution that runs in logarithmic time. This post is about those potential solutions.

Non-random solutions

When coming to the problem, the first thing to decide is whether we really want the selection to be random or not. One could imagine a function that tries to keep track of previous selections, for example;

elements = [ { 'name' : 'A' , 'weight' : 2 },
             { 'name' : 'B' , 'weight' : 1 },
             { 'name' : 'C' , 'weight' : 1 } ]

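def select_element(elements, count):
  # count cycles through 0 .. (sum of weights - 1), not 0 .. (number of elements - 1)
  total_weight = sum(element['weight'] for element in elements)
  cumulative_weight = 0
  for element in elements:
    cumulative_weight += element['weight']
    if count < cumulative_weight:
      return  ( element ,
                (count + 1) % total_weight )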

# select some elements
c = 0
(i , c) = select_element(elements, c)
(j , c) = select_element(elements, c)
(k , c) = select_element(elements, c)

This function uses the “count” parameter to remember where it left off last time, and runs in time proportional to the number of elements. Used correctly, the first two calls return A, the third returns B, the fourth returns C, and then the cycle repeats with A twice again, and so on. Hopefully that’s obvious.
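If the caller doesn't want to thread `count` through every call, a Python generator can hold the position instead. This is a hypothetical alternative, not the post's code: `weighted_round_robin` yields each element `weight` times in a row and then moves on, so it produces the same A, A, B, C cycle with constant work per selection (for positive weights) and without expanding the list.

```python
import itertools


def weighted_round_robin(elements):
    """Yield elements forever, each one `weight` times per cycle.

    The generator's suspended state plays the role of the `count`
    parameter, so callers just call next() on it.
    """
    for element in itertools.cycle(elements):
        for _ in range(element['weight']):
            yield element


elements = [{'name': 'A', 'weight': 2},
            {'name': 'B', 'weight': 1},
            {'name': 'C', 'weight': 1}]
selector = weighted_round_robin(elements)
first = next(selector)  # the first 'A'
```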

There’s a very simple optimisation that makes it run in constant time: just sacrifice memory proportional to the total sum of weights;

elements = [ { 'name' : 'A' },
             { 'name' : 'A' },
             { 'name' : 'B' },
             { 'name' : 'C' } ]

def select_element(elements, count):
  return  ( elements[count] ,
            (count + 1) % len(elements) )

This basic approach is surprisingly common; it's how a lot of networking devices implement weighted path selection.
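Writing the duplicated entries out by hand gets tedious for larger weights. A small helper (the name `flatten` is hypothetical) can build the flattened list from the weighted table, after which the constant-time selector needs only the list and the count. The same flattened list can also be handed straight to `random.choice` for the random variant.

```python
def flatten(elements):
    """Expand a weighted table into a list holding `weight` copies of each element."""
    return [element for element in elements for _ in range(element['weight'])]


def select_element(flat, count):
    """Constant-time deterministic selection from a pre-flattened list."""
    return flat[count], (count + 1) % len(flat)


flat = flatten([{'name': 'A', 'weight': 2},
                {'name': 'B', 'weight': 1},
                {'name': 'C', 'weight': 1}])
```

The trade-off is visible in `len(flat)`: it equals the sum of the weights, not the number of distinct elements.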

Random solutions

But if we're feeding into some kind of simulation or a fuzz test, then randomised selection is probably a better fit. Luckily, there are at least 3 ways to do it. The first is to use the same space-inefficient layout that our "optimised" non-random selection used: flatten the list of elements into an array and just jump into it at random;

import random

elements = [ { 'name' : 'A' },
             { 'name' : 'A' },
             { 'name' : 'B' },
             { 'name' : 'C' } ]

def select_element(elements):
  return random.choice(elements)

As before, this runs in constant time, but it's easy to see how this could get really ugly if there are weights that look more like;

element   weight
A         998
B         1
C         1

Unfortunately, common datasets can follow distributions that are just like this (Zipfian, for example). It would take an unreasonably large amount of space to store all of the words in the English language, proportionate to their frequencies, using this method.

Reservoir sampling

But luckily we have a way to avoid all of this space, and it has some other useful applications too. It's called reservoir sampling, and it's pretty magical. Reservoir sampling is a form of statistical sampling that lets you choose a limited number of samples from an arbitrarily large stream of data, without knowing how big it is in advance. It doesn't seem related to weighted selection, but it is.

Imagine you have a stream of data, and it looks something like;

"Banana" , "Car" , "Bus", "Apple", "Orange" , "Banana" , ....

and you want to choose a sample of 3 events. All events in the stream should have equal probability of making it into the sample. A better real-world example is collecting a sample of 1,000 web server requests over the last minute, but you get the idea.

What a reservoir sample does is simple;

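import random

def reservoir_sample(events, size):
  events_observed = 0
  sample = [ None ] * size

  for event in events:
    if events_observed >= len(sample):
      # pick a position uniformly from 0 .. events_observed (inclusive)
      r = random.randint(0, events_observed)
    else:
      # the first `size` events fill the sample directly
      r = events_observed

    # keep the event only if r landed inside the sample
    if r < len(sample):
      sample[r] = event
    events_observed += 1

  return sample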

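And how this works is pretty cool. With a sample size of 3, the first 3 events have a 100% likelihood of being sampled: r will be 0, 1 and 2 in their case. The interesting part is what happens after that. The fourth element has a 3/4 probability of being selected, since r is drawn from 0, 1, 2, 3 and three of those four values fall inside the sample. So that's pretty simple.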

But consider the likelihood of an element already in the sample "staying". For any given element, there are two ways of staying. One is that the fourth element is not chosen (1/4 probability); the other is that the fourth element is chosen (3/4 probability) but a different element (2/3 probability) is the one replaced. If you do the math;

1/4 + (3/4 * 2/3)   =>
1/4 + 6/12          =>
3/12 + 6/12         =>
9/12                =>
3/4

We see that any given element has a 3/4 likelihood of staying, which is exactly what we want. There have been four elements observed, and all of them have a 3/4 probability of being in our sample. Let's extend this by one iteration, and see what happens at element 5.

When we get to this element, it has a 3/5 chance of being selected (r is a random number from the set 0, 1, 2, 3, 4; if it is one of 0, 1 or 2 then the element will be chosen). We've already established that the previous elements had a 3/4 probability of being in the sample. Those elements that are in the sample again have two ways of staying: either the fifth element isn't chosen (2/5 likelihood), or it is, but a different element is replaced (3/5 * 2/3). Again, let's do the math;

3/4 * (2/5 + (3/5 * 2/3))  =>
3/4 * (2/5 + 6/15)         =>
3/4 * (2/5 + 2/5)          =>
3/4 * 4/5                  =>
3/5

So, once again, all elements have a 3/5 likelihood of being in the sample. Each time another element is observed the calculation follows the same pattern, so every element always has an equal probability of being in the final sample.
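The same step can be written once for any sample size k (3 above) and any number n ≥ k of events already observed. Assume each observed event is in the sample with probability k/n. The next event is admitted with probability k/(n+1), and when admitted it overwrites any particular slot with probability 1/k, so a given resident is evicted with probability 1/(n+1):

$$
P_{n+1} = \frac{k}{n}\left(1 - \frac{k}{n+1}\cdot\frac{1}{k}\right)
        = \frac{k}{n}\cdot\frac{n}{n+1}
        = \frac{k}{n+1}
$$

This matches the probability k/(n+1) that the new event itself is admitted. With k = 3 and n = 4 it gives 3/4 · 4/5 = 3/5, as in the worked example above.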

So what does this have to do with random selection? Well, imagine our original weighted elements as a stream of events, and that our goal is to choose a sample of 1.

import random

elements = [ { 'name' : 'A' },
             { 'name' : 'A' },
             { 'name' : 'B' },
             { 'name' : 'C' } ]

def select_element(elements):
  elements_observed = 0
  chosen_element = None
  for element in elements:
    # keep this element with probability 1 / (elements_observed + 1)
    r = random.randint(0, elements_observed)
    if r < 1:
        chosen_element = element
    elements_observed += 1
  return chosen_element

So far, this is actually worse than what we've had before. We're still using space in memory proportionate to the total sum of weights, and we're running in linear time. But now that we've structured things like this, we can use the weights to cheat;

import random

elements = [ { 'name' : 'A' , 'weight' : 2 },
             { 'name' : 'B' , 'weight' : 1 },
             { 'name' : 'C' , 'weight' : 1 } ]

def select_element(elements):
  total_weight = 0
  chosen_element = None
  for element in elements:
    total_weight += element['weight']
    # replace the current choice with probability weight / total_weight
    r = random.randint(0, total_weight - 1)
    if r < element['weight']:
         chosen_element = element

  return chosen_element

It's the same kind of self-correcting math as before, except now we can take bigger jumps rather than having to take steps of just 1 every time.
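Written out, let W_j = w_1 + ... + w_j be the running total after j elements. Element i is chosen when it is visited with probability w_i/W_i, and each later element j leaves that choice alone with probability 1 - w_j/W_j = W_{j-1}/W_j. The product telescopes:

$$
P(i) = \frac{w_i}{W_i}\prod_{j=i+1}^{n}\frac{W_{j-1}}{W_j}
     = \frac{w_i}{W_i}\cdot\frac{W_i}{W_n}
     = \frac{w_i}{W_n}
$$

For the A/B/C table, A is chosen with probability 2/2 and then survives B and C with probability 2/3 · 3/4, giving 1/2.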

We now have a loop that runs in time proportionate to the total number of elements, and also uses memory proportionate to the total number of elements. That's pretty good, but there's another optimisation we can make.
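Because the weighted loop only ever looks at the running total and the current pick, it also works in a single pass over a stream whose length and total weight aren't known in advance, much like the unweighted reservoir sample. Here is a minimal Python 3 sketch of that, with the hypothetical name `weighted_stream_choice`, taking `(item, weight)` pairs rather than the dicts used above. `random.randrange(total_weight)` draws from the same range as `random.randint(0, total_weight - 1)`.

```python
import random
from collections import Counter


def weighted_stream_choice(pairs):
    """One-pass weighted choice over an iterable of (item, weight) pairs.

    Only the running total and the current pick are kept, so the input
    can be a generator of unknown length. Returns None for an empty stream.
    """
    total_weight = 0
    chosen = None
    for item, weight in pairs:
        total_weight += weight
        # Replace the current pick with probability weight / total_weight.
        if weight > 0 and random.randrange(total_weight) < weight:
            chosen = item
    return chosen


# Example: pick a word in proportion to how often it occurs.
word_counts = Counter("the cat and the hat and the bat".split())
print(weighted_stream_choice(word_counts.items()))
```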

Using a tree

If there are a large number of elements, it can still be a pain to iterate over them all just to select one. One solution is to compile the weighted elements into a weighted binary tree, so that each selection needs only O(log n) operations.

Let's take a larger weighted set;

element   weight
A         1
B         2
C         3
D         1
E         1
F         1
G         1
H         1
I         1
J         1

which we can express in tree form, where each internal node carries the total weight of its children.

Tree-solutions like this lend themselves easily to recursion;

import random

elements = [ {  'name' : 'A',
                'weight' : 10 },
             {  'name' : 'B',
                'weight' : 20 },
             {  'name' : 'C',
                'weight' : 30 },
             {  'name' : 'D',
                'weight' : 10 },
             {  'name' : 'E',
                'weight' : 10 },
             {  'name' : 'F',
                'weight' : 10 },
             {  'name' : 'G',
                'weight' : 10 },
             {  'name' : 'H',
                'weight' : 10 },
             {  'name' : 'I',
                'weight' : 10 },
             {  'name' : 'J',
                'weight' : 10 } ]

# Compile a set of weighted elements into a weighted tree
def compile_choices(elements):
    if len(elements) > 1:
        # integer division: in Python 3, / would give a float and break slicing
        left  =  elements[ : len(elements) // 2 ]
        right =  elements[ len(elements) // 2 : ]

        return [ { 'child': compile_choices(left) ,
                   'weight' : sum( map( lambda x : x['weight'] , left ) ) } ,
                 { 'child': compile_choices(right) ,
                   'weight' : sum( map( lambda x : x['weight'] , right ) ) } ]
    else:
        return elements

# Choose an element from a weighted tree
def tree_weighted_choice(tree):
    if len(tree) > 1:
        total_weight = tree[0]['weight'] + tree[1]['weight']
        if random.randint(0, total_weight - 1) < tree[0]['weight']:
            return tree_weighted_choice( tree[0]['child'] )
        else:
            return tree_weighted_choice( tree[1]['child'] )
    else:
        return tree[0]['name']

tree = compile_choices(elements)

And there we have it! A pretty good trade-off: we get logarithmic selection time with a relatively small overhead in memory.
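The same logarithmic bound can also be had without an explicit tree, by binary-searching a list of running totals; Fergal Daly suggests this in the comments below. A minimal Python 3 sketch, with the hypothetical names `compile_cumulative` and `bisect_weighted_choice`, using the standard `itertools.accumulate` and `bisect.bisect_right`. Like the balanced tree, it always does about log2(n) probes regardless of the weights, so it doesn't favour heavy elements the way the weight-balanced tree below does.

```python
import bisect
import itertools
import random


def compile_cumulative(elements):
    """Return (running totals of the weights, names in the same order)."""
    cumulative = list(itertools.accumulate(e['weight'] for e in elements))
    names = [e['name'] for e in elements]
    return cumulative, names


def bisect_weighted_choice(cumulative, names, rng=random):
    """O(log n) weighted choice by binary search over running totals."""
    # r is uniform over 0 .. total-1. Element i owns the integers in
    # [cumulative[i-1], cumulative[i]), and bisect_right finds that i.
    r = rng.randrange(cumulative[-1])
    return names[bisect.bisect_right(cumulative, r)]


cumulative, names = compile_cumulative([{'name': 'A', 'weight': 2},
                                        {'name': 'B', 'weight': 1},
                                        {'name': 'C', 'weight': 1}])
```

With these weights `cumulative` is `[2, 3, 4]`, so draws 0 and 1 map to A, 2 to B and 3 to C.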

A Huffman(ish) tree

I should have included this in the first version of this post, and Fergal was quick to point it out in the comments, but there is a further optimisation we can make. Instead of using a tree that is balanced purely by the number of nodes, as above, we can use a tree that is balanced by the sum of weights.

To re-use his example, imagine if you have elements that are weighted [ 10000, 1 , 1 , 1 ... ] (with a hundred ones). It doesn't make sense to bury the very weighty element deep in the tree, instead it should go near the root - so that on average we minimise the expected number of lookups.

def huffman_choices(elements):
    if len(elements) > 1:
        total_weight = sum( map( lambda x : x['weight'] , elements ) )
        sorted_elements = sorted(elements, key = lambda x : x['weight'] )[::-1]

        left = []
        right = []
        observed_weight = 0
        for element in sorted_elements:
            if observed_weight < total_weight / 2:
                left.append(element)
            else:
                right.append(element)

            observed_weight += element['weight']

        return [ { 'child': huffman_choices(left) ,
                   'weight' : sum( map( lambda x : x['weight'] , left ) ) } ,
                 { 'child': huffman_choices(right) ,
                   'weight' : sum( map( lambda x : x['weight'] , right ) ) } ]
    else:
        return elements
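For comparison, the classic Huffman construction works bottom-up instead of top-down: it repeatedly merges the two lightest subtrees. Huffman's algorithm is known to minimise the weighted sum of leaf depths, which here is the expected number of branch decisions per selection; the top-down split above is closer in spirit to Shannon-Fano coding. A minimal Python 3 sketch using `heapq`, producing the same nested-list format that `tree_weighted_choice` walks. The names `huffman_tree` and `leaf_depths` are hypothetical, and a non-empty input is assumed.

```python
import heapq
import itertools


def huffman_tree(elements):
    """Build a Huffman tree by repeatedly merging the two lightest subtrees.

    A counter breaks weight ties so heapq never has to compare the
    subtree lists themselves.
    """
    tiebreak = itertools.count()
    heap = [(e['weight'], next(tiebreak), [e]) for e in elements]
    heapq.heapify(heap)
    while len(heap) > 1:
        w1, _, t1 = heapq.heappop(heap)
        w2, _, t2 = heapq.heappop(heap)
        node = [{'child': t1, 'weight': w1}, {'child': t2, 'weight': w2}]
        heapq.heappush(heap, (w1 + w2, next(tiebreak), node))
    return heap[0][2]


def leaf_depths(tree, depth=0):
    """Map each leaf name to its depth (the number of branch decisions)."""
    if len(tree) > 1:
        depths = {}
        for branch in tree:
            depths.update(leaf_depths(branch['child'], depth + 1))
        return depths
    return {tree[0]['name']: depth}
```

With the `[10000, 1, 1, ...]` example, the heavy element ends up directly below the root, one decision away.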

None of these techniques are novel; in fact they're all quite common and very old, yet for some reason they aren't widely known. Reservoir sampling, which on its own is incredibly useful, is only afforded a few sentences in Knuth.

But dealing with randomness and sampling is one of the inevitable complexities of programming. If you've never done it, it's worth taking the above and trying to write your own weighted Markov chain generator. And then have even more fun thinking about how to test it.
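One simple place to start on testing, sketched here under the assumption that a fixed tolerance is good enough for a smoke test, is to call a selector many times and compare the observed shares against the weights. The name `check_frequencies` and its defaults are hypothetical. A proper goodness-of-fit test such as chi-squared is more rigorous, and any randomised check like this can occasionally fail by chance.

```python
import random
from collections import Counter


def check_frequencies(select, expected, trials=100_000, tolerance=0.01):
    """Call select() `trials` times and compare observed to expected shares.

    `expected` maps each outcome to its weight. Returns True when no
    unexpected outcome appears and every observed share is within
    `tolerance` of weight / total_weight.
    """
    total = sum(expected.values())
    tally = Counter(select() for _ in range(trials))
    if set(tally) - set(expected):
        return False  # an outcome that should never occur
    return all(abs(tally[k] / trials - w / total) <= tolerance
               for k, w in expected.items())


# Example: the flattened-list random selector from earlier in the post.
flat = ['A', 'A', 'B', 'C']
print(check_frequencies(lambda: random.choice(flat), {'A': 2, 'B': 1, 'C': 1}))
```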

10 Replies to "Weighty matters"

Fergal Daly  on May 15, 2011

I’d never noticed before that reservoir sampling is just a Fisher-Yates shuffle that shuffles the whole deck but only cares about the final position of cards 1-N.

compile_choices for the tree case doesn’t build an optimal binary search tree (a tree that minimises the expected number of comparisons for the given weights). Since the weights _are_ the keys, I think just sorting by weight first and then branching the tree so that the weights are balanced will give something very close (I think it will be a Huffman tree).

E.g. if the weights are {10000, 1, 1, …, 1} (100 1s in there), you definitely want the 10000 node at the top of the tree, then in 99% of cases you select an element with 1 comparison, instead of 7.

Building an optimal binary search tree is fun too.

David Malone  on May 15, 2011

I guess if you know the weights in advance, you can use Huffman coding to generate the tree so you have the least depth to walk through on average?

The comments aren’t so far from that password paper that I sent you – did you get a look at Section 6? It’s right up this street (I’ve uploaded a version at http://arxiv.org/abs/1104.3722 )

Fergal – I never knew that Fisher-Yates and Algorithm P were the same. You learn something new every day.

pixelbeat  on May 15, 2011

@Fergal coincidentally yesterday I added an optimization to the knuth shuffle used in the `shuf` command, used when you only want the start of the “deck”. http://git.sv.gnu.org/gitweb/?p=coreutils.git;a=commitdiff;h=27873f1d

colmmacc  on May 15, 2011

I like the Huffman tree approach – I’ve used that too, but they can be super expensive to build when you throw some seriously large data at it. One thing I like about the simpler tree here is that building it is easily parallelisable. For a dataset like tuple frequency in the English language that matters. Though with foreknowledge, a heap could be used to keep the input sorted.

colmmacc  on May 15, 2011

David,

I read the paper, but didn’t twig the correlation. Now that I read it again, it makes more sense. I remember going to a lecture you gave in TCD a few years ago where you demonstrated some Markov generation using the bible as a source input. Do you still have that? It’d be interesting to see how it was tuned.

On the passwords – if we’re willing to reject a password because it’s over-used, and hence leak some information, what about enforcing that all passwords in a system are unique? A large bloom filter could be used to do this probabilistically if space is a concern.

colmmacc  on May 15, 2011

I’ve updated the post to include a Huffman tree. I think the optimal solution (for smallest expectation value of lookups) is to balance by the sum of weights. A Huffman coding tree (where the most common element always gets the smallest available path) won’t be optimal because the tree will be too imbalanced in many cases.

Fergal Daly  on May 15, 2011

Actually, you can avoid trees altogether, just build 2 lists, one with the observed_weight (from the Huffman code) and the other with the corresponding element. Then calling python’s bisect.bisect on the observed_weights list will give you an index into the element list.

I’m being deliberately vague since there’s a lot of off-by-onery to be had here.

Anyway, unless you need to insert more elements into the middle of it or adjust the weights, this is nice and compact (and in python at least uses a handy library).

Actually, you can update this after the fact by simply adding more entries. It doesn’t matter if they have the same values as entries in there already. Obviously you don’t want to do that too much since you lose the benefits of having coalesced all occurrences of each item into a single count, but you could run a compaction every now and then if that was happening.

colmmacc  on May 15, 2011

Fergal, is that any better than using a skip list?

Fergal Daly  on May 15, 2011

Actually, now that I think a bit more about it, what I described is really only useful for the original tree version, not the huffman one since it wouldn’t put the heavy items near the root.

For something better, you could do as I said with the list but lay it out so that the heaviest is at N/2, the next heaviest are at N/4 and 3N/4, but it’s getting messy now and you can’t append to it any more.

Either way, in any static tree versions, I think you want to use the cumulative weight as the key because you’re going to be generating random.uniform(0, total_weight) and then seeing where in the tree that leads you.

In the tree in your diagram, to select an element, I pick R from 1-13, if it’s 8 then I repeat with R – 5 on the right subtree. Using the cumulative weight as the key avoids this subtraction every time you branch right.

I have never used skip lists for anything so pinch of salt… Are you talking about the case where you want to do more updates afterwards? It seems like maybe an indexable skip list could be adapted. Usually links have widths but maybe elements could have widths too. So when you increment the count of an element that increment has to flow up to the root.

Then again, that seems like it could be done with a binary tree too (using the summed weights as in your diagram, not the cumulative ones).

David Malone  on May 17, 2011

I still have the code that I used to generate the Bible travesty. It was a slightly clunky bit of perl that made good use of associative arrays, but really just used brute force. I’ll send it on to you. Metropolis–Hastings is a cool algorithm for sampling too though – I suspect if it was better known outside the physical simulation community, it would get more use.

You could enforce uniqueness of passwords using a bloom filter, but I suspect the cost in terms of number-of-retries could be quite high – I’ll see if I can figure out how to guess. If you have a look at the paper “Popularity is everything: a new approach to protecting passwords from statistical-guessing attacks” that we cite, it has a nice technique using min-count-sketches (a close relative of the Bloom filter) to stop a particular password from becoming too popular.
