A Concurrency-safe, Iterable Binary Search Tree

In today’s article, we will implement a simple, unbalanced, concurrency-safe binary search tree in Golang. We will use some nice patterns along the way that are worth knowing.

I assume that the valued reader has a basic understanding of the Go programming language and knows what a binary search tree is. What I want to show is how easy it is to create a concurrency-safe version of this basic data structure in Golang.

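Let’s begin! The API for our binary search tree should offer the following actions: inserting or updating the value for a key (Upsert), retrieving the value for a key (Value), deleting a key (Delete), and iterating over all items in key order (Iter).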

The key shall be a string that is unique within an instance of the tree; it identifies an individual node. Regarding the value that is stored, we will be more agnostic and use an empty interface, thus allowing all kinds of data to be stored in our tree. From the data structure's point of view, we do not care about the value. So why should we care about its type?

The Tree

For now, the tree is represented by a simple type that holds just the root node:

type BSTree struct {
    root *node
}

The Nodes

For the nodes we will be using a simple struct with fields for the key, value, and pointers to child nodes:

type node struct {
  key   string
  val   interface{}
  left  *node
  right *node
}

Retrieving the value of a node is, no surprise here, a recursive function. It leverages the fact that a binary search tree is, by definition, sorted. If the key we are searching for is smaller than the key of the node we are looking at, we proceed to the left child. If the key is larger, we proceed to the right child. If neither applies, we have found the node we were looking for. If we run off the end of the tree and reach a nil node, the key is not in the tree.

func (n *node) value(key string) (interface{}, error) {
    if n == nil {
        return nil, ErrorNotFound
    }
    if key < n.key {
        return n.left.value(key)
    } else if key > n.key {
        return n.right.value(key)
    }
    return n.val, nil
}
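The recursion relies on a Go detail worth spelling out: calling a method on a nil pointer receiver is legal, so n.left.value(key) is safe even when n.left is nil. The following is a minimal, self-contained sketch that builds a three-node tree by hand. ErrorNotFound is declared here with errors.New purely for the sake of the example, and the keys and values are made up.

```go
package main

import (
    "errors"
    "fmt"
)

// ErrorNotFound is declared locally so that this sketch compiles on its own.
var ErrorNotFound = errors.New("not found")

type node struct {
    key   string
    val   interface{}
    left  *node
    right *node
}

// value walks down the tree; a nil receiver means the key is absent.
func (n *node) value(key string) (interface{}, error) {
    if n == nil {
        return nil, ErrorNotFound
    }
    if key < n.key {
        return n.left.value(key)
    } else if key > n.key {
        return n.right.value(key)
    }
    return n.val, nil
}

func main() {
    // Hand-built tree: "m" is the root, "c" its left and "x" its right child.
    root := &node{
        key:   "m",
        val:   1,
        left:  &node{key: "c", val: 2},
        right: &node{key: "x", val: 3},
    }

    v, err := root.value("x")
    fmt.Println(v, err) // 3 <nil>

    _, err = root.value("q") // "q" would sit left of "x", but that slot is nil
    fmt.Println(err)         // not found
}
```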

Upsert

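Upsert, the union of the update and insert operations, works very similarly. The only difference is that if we cannot find a node with the given key, we create it. Note that, unlike value(), upsert() must not be called on a nil node, because it reads n.key right away.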

func (n *node) upsert(key string, val interface{}) {
    if key < n.key {
        if n.left == nil {
            n.left = &node{key: key, val: val}
        } else {
            n.left.upsert(key, val)
        }
    } else if key > n.key {
        if n.right == nil {
            n.right = &node{key: key, val: val}
        } else {
            n.right.upsert(key, val)
        }
    } else {
        n.val = val
    }
}
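To see what "unbalanced" means in practice, here is a small sketch that inserts the same seven keys in two different orders and compares the resulting tree heights. The height() and build() helpers are hypothetical additions for illustration only; they are not part of the tree's API.

```go
package main

import "fmt"

type node struct {
    key   string
    val   interface{}
    left  *node
    right *node
}

func (n *node) upsert(key string, val interface{}) {
    if key < n.key {
        if n.left == nil {
            n.left = &node{key: key, val: val}
        } else {
            n.left.upsert(key, val)
        }
    } else if key > n.key {
        if n.right == nil {
            n.right = &node{key: key, val: val}
        } else {
            n.right.upsert(key, val)
        }
    } else {
        n.val = val
    }
}

// height is a hypothetical helper: the number of nodes on the longest
// path from n down to a leaf. A nil node has height 0.
func (n *node) height() int {
    if n == nil {
        return 0
    }
    l, r := n.left.height(), n.right.height()
    if l > r {
        return l + 1
    }
    return r + 1
}

// build is a hypothetical helper: it creates the root from the first key
// and upserts the remaining keys in the order given.
func build(keys ...string) *node {
    root := &node{key: keys[0], val: 0}
    for i, k := range keys[1:] {
        root.upsert(k, i+1)
    }
    return root
}

func main() {
    fmt.Println(build("d", "b", "f", "a", "c", "e", "g").height()) // 3
    fmt.Println(build("a", "b", "c", "d", "e", "f", "g").height()) // 7
}
```

Inserting keys in sorted order degenerates the tree into a linked list, so lookups take linear instead of logarithmic time. That is the price of leaving out rebalancing.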

Concurrency-safe API using a Read-Write-Mutex

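We will take two steps to ensure a concurrency-safe API: first, we add a sync.RWMutex to the tree type; second, every public method acquires that lock before delegating to the private methods on the nodes.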

It is considered good style to place the synchronizing entity directly above the field that it is protecting:

type BSTree struct {
    lock sync.RWMutex // <-- This one is protecting 'root'
    root *node
}

Here the sync.RWMutex is protecting the root node and all its descendants.

The public method Upsert() of the BSTree type acquires the write lock and then calls the private method upsert() on the root node. Releasing the lock is deferred, so it happens however the method returns.

func (b *BSTree) Upsert(key string, val interface{}) {
    b.lock.Lock()
    defer b.lock.Unlock()
    // if root node is empty, new node is root now
    if b.root == nil {
        b.root = &node{key: key, val: val}
    } else {
        b.root.upsert(key, val)
    }
}

Using this simple pattern, we can keep the methods that actually change the data free of synchronization artifacts. At the same time, all the locking and unlocking is concentrated in the public methods. I like this pattern because it gives me a quick overview of the entry points through which callers reach my library code.
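As a quick sanity check of the pattern, the following sketch lets several goroutines call Upsert() on the same tree at once. The fill() function and the Len()/count() helpers are hypothetical and exist only for this demonstration, as does the "w<worker>-k<index>" key scheme.

```go
package main

import (
    "fmt"
    "sync"
)

type node struct {
    key   string
    val   interface{}
    left  *node
    right *node
}

func (n *node) upsert(key string, val interface{}) {
    if key < n.key {
        if n.left == nil {
            n.left = &node{key: key, val: val}
        } else {
            n.left.upsert(key, val)
        }
    } else if key > n.key {
        if n.right == nil {
            n.right = &node{key: key, val: val}
        } else {
            n.right.upsert(key, val)
        }
    } else {
        n.val = val
    }
}

// count is a hypothetical helper that returns the number of nodes below n.
func (n *node) count() int {
    if n == nil {
        return 0
    }
    return 1 + n.left.count() + n.right.count()
}

type BSTree struct {
    lock sync.RWMutex
    root *node
}

func (b *BSTree) Upsert(key string, val interface{}) {
    b.lock.Lock()
    defer b.lock.Unlock()
    if b.root == nil {
        b.root = &node{key: key, val: val}
    } else {
        b.root.upsert(key, val)
    }
}

// Len is a hypothetical read-only method; it only needs the read lock.
func (b *BSTree) Len() int {
    b.lock.RLock()
    defer b.lock.RUnlock()
    return b.root.count()
}

// fill starts the given number of goroutines, each upserting its own keys.
func fill(workers, perWorker int) *BSTree {
    tree := &BSTree{}
    var wg sync.WaitGroup
    for w := 0; w < workers; w++ {
        wg.Add(1)
        go func(w int) {
            defer wg.Done()
            for i := 0; i < perWorker; i++ {
                tree.Upsert(fmt.Sprintf("w%d-k%d", w, i), i)
            }
        }(w)
    }
    wg.Wait()
    return tree
}

func main() {
    fmt.Println(fill(8, 100).Len()) // 800
}
```

Running such a program with the race detector enabled (go run -race) is a cheap way to confirm that no access to the tree bypasses the lock.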

Iterating using a Channel

Iterating over a binary search tree in key order is relatively easy. We recursively walk through all the nodes: first the left subtree, then the node itself, and finally the right subtree. Assuming we want to send the items into a channel, the code might look like this:

func (n *node) iter(ch chan<- Item) {
    if n == nil {
        return
    }
    n.left.iter(ch)
    ch <- Item{
        Key: n.key,
        Val: n.val,
    }
    n.right.iter(ch)
}

The channel carries values of type Item, a simple struct that holds the key and the value of a node. It looks like this:

type Item struct {
    Key string
    Val interface{}
}
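Before adding any locking, the traversal order can be checked in isolation. This sketch is a simplification for demonstration purposes: it uses a buffered channel that is large enough for all items, so iter() can run to completion without a concurrent receiver. The keysInOrder() helper and the hand-built tree are hypothetical.

```go
package main

import "fmt"

type Item struct {
    Key string
    Val interface{}
}

type node struct {
    key   string
    val   interface{}
    left  *node
    right *node
}

// iter sends the items of the subtree rooted at n in ascending key order.
func (n *node) iter(ch chan<- Item) {
    if n == nil {
        return
    }
    n.left.iter(ch)
    ch <- Item{
        Key: n.key,
        Val: n.val,
    }
    n.right.iter(ch)
}

// keysInOrder is a hypothetical helper. size must be at least the number
// of nodes; otherwise the sends in iter would block forever, because
// nobody receives until iter has returned.
func keysInOrder(root *node, size int) []string {
    ch := make(chan Item, size)
    root.iter(ch)
    close(ch)

    var keys []string
    for item := range ch {
        keys = append(keys, item.Key)
    }
    return keys
}

func main() {
    root := &node{
        key:   "m",
        left:  &node{key: "c"},
        right: &node{key: "x", left: &node{key: "p"}},
    }
    fmt.Println(keysInOrder(root, 4)) // [c m p x]
}
```

Sizing a buffer to the whole tree is not practical for a real API, since the size is not known up front and the buffer would hold a copy of every item.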

Now comes the tricky part. We want the iter() method to be used by external code in a concurrency-safe way. This means we can allow one or more readers, each one having its own channel for Item retrieval, but no writer until the last reading channel is closed. The Read-Write synchronization is taken care of by the mutex already. Nice! However, parallel, non-blocking reading through multiple channels is not that easy. It asks for a smart use of channels and goroutines.

Let’s start simple: A reader calling Iter() shall get back a receive-only channel. Obviously, we have to create this channel at the beginning of the method and return it when the method ends. This is easy:

func (b *BSTree) Iter() <-chan Item {
    ch := make(chan Item)
    b.lock.RLock()
    // here be dragons
    return ch
}

But wait, don’t we have to close the channel eventually? And how do we keep the program from blocking, given that the reader cannot receive from the channel before we have returned it? Furthermore, the channel is unbuffered, so each send blocks until the reader has received the previous item. And when do we unlock the mutex?

The solution is to delegate all these tasks to a goroutine. The function literal that the goroutine runs is a closure, so it can use the variables of the surrounding method (b and ch). This allows us to call the iter() method on the root node, unlock the mutex, and close the channel, all inside the goroutine. With that, everything is taken care of and we can return the channel we just created right away.

func (b *BSTree) Iter() <-chan Item {
    ch := make(chan Item)
    b.lock.RLock()
    go func() {
        b.root.iter(ch)
        b.lock.RUnlock()
        close(ch)
    }()
    return ch
}

It is OK to be confused by the code at first. A lot is happening here and the pattern might not be clear from the beginning. Feel free to take your time to grasp it.
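Seen from the caller's side, the pattern is pleasant to use: ranging over the returned channel ends once the goroutine closes it. The sketch below repeats the relevant types so that it compiles on its own and runs two readers in parallel. The collectKeys() helper and the fruit keys are hypothetical.

```go
package main

import (
    "fmt"
    "sync"
)

type Item struct {
    Key string
    Val interface{}
}

type node struct {
    key   string
    val   interface{}
    left  *node
    right *node
}

type BSTree struct {
    lock sync.RWMutex
    root *node
}

func (n *node) upsert(key string, val interface{}) {
    if key < n.key {
        if n.left == nil {
            n.left = &node{key: key, val: val}
        } else {
            n.left.upsert(key, val)
        }
    } else if key > n.key {
        if n.right == nil {
            n.right = &node{key: key, val: val}
        } else {
            n.right.upsert(key, val)
        }
    } else {
        n.val = val
    }
}

func (b *BSTree) Upsert(key string, val interface{}) {
    b.lock.Lock()
    defer b.lock.Unlock()
    if b.root == nil {
        b.root = &node{key: key, val: val}
    } else {
        b.root.upsert(key, val)
    }
}

func (n *node) iter(ch chan<- Item) {
    if n == nil {
        return
    }
    n.left.iter(ch)
    ch <- Item{Key: n.key, Val: n.val}
    n.right.iter(ch)
}

func (b *BSTree) Iter() <-chan Item {
    ch := make(chan Item)
    b.lock.RLock()
    go func() {
        b.root.iter(ch)
        b.lock.RUnlock()
        close(ch)
    }()
    return ch
}

// collectKeys is a hypothetical helper that drains the iterator completely.
func collectKeys(tree *BSTree) []string {
    var keys []string
    for item := range tree.Iter() {
        keys = append(keys, item.Key)
    }
    return keys
}

func main() {
    tree := &BSTree{}
    for i, k := range []string{"pear", "apple", "fig", "kiwi"} {
        tree.Upsert(k, i)
    }

    // Two readers iterate at the same time, each over its own channel.
    results := make([][]string, 2)
    var wg sync.WaitGroup
    for r := range results {
        wg.Add(1)
        go func(r int) {
            defer wg.Done()
            results[r] = collectKeys(tree)
        }(r)
    }
    wg.Wait()

    fmt.Println(results[0]) // [apple fig kiwi pear]
    fmt.Println(results[1]) // [apple fig kiwi pear]
}
```

Two caveats follow directly from the code. A reader should always drain the channel: if it stops receiving early, the goroutine stays blocked on its next send and the read lock is never released, so every later writer waits forever. For the same reason, calling Upsert() or Delete() from inside the range loop deadlocks, because the write lock cannot be granted until the iteration has finished.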

Conclusion

Implementing a binary search tree in Golang is straightforward and does not require deep language knowledge. By adding synchronizing tools like mutexes and channels, data structures can be made ready for showtime in concurrent programs.

Source

package bstree

import (
    "fmt"
    "sync"
)

// Item holds the key and value of a node to be returned by an iterator
type Item struct {
    Key string
    Val interface{}
}

type node struct {
    key   string
    val   interface{}
    left  *node
    right *node
}

// BSTree represents a binary search tree
type BSTree struct {
    lock sync.RWMutex
    root *node
}

// ErrorNotFound is returned when a key is not in the binary search tree
var ErrorNotFound = fmt.Errorf("not found")

func (n *node) value(key string) (interface{}, error) {
    if n == nil {
        return nil, ErrorNotFound
    }
    if key < n.key {
        return n.left.value(key)
    } else if key > n.key {
        return n.right.value(key)
    }
    return n.val, nil
}

// Value returns the data associated with a given key
func (b *BSTree) Value(key string) (interface{}, error) {
    b.lock.RLock()
    defer b.lock.RUnlock()
    return b.root.value(key)
}

func (n *node) upsert(key string, val interface{}) {
    if key < n.key {
        if n.left == nil {
            n.left = &node{key: key, val: val}
        } else {
            n.left.upsert(key, val)
        }
    } else if key > n.key {
        if n.right == nil {
            n.right = &node{key: key, val: val}
        } else {
            n.right.upsert(key, val)
        }
    } else {
        n.val = val
    }
}

// Upsert updates or inserts data associated to a given key
func (b *BSTree) Upsert(key string, val interface{}) {
    b.lock.Lock()
    defer b.lock.Unlock()
    // if root node is empty, new node is root now
    if b.root == nil {
        b.root = &node{key: key, val: val}
    } else {
        b.root.upsert(key, val)
    }
}

func (n *node) isLeaf() bool {
    return !n.hasLeft() && !n.hasRight()
}

func (n *node) hasLeft() bool {
    return n.left != nil
}

func (n *node) hasRight() bool {
    return n.right != nil
}

func (n *node) min() *node {
    for ; n.left != nil; n = n.left {
    }
    return n
}

func (n *node) delete(key string) (*node, error) {
    var err error
    if n == nil {
        return nil, ErrorNotFound
    }
    if key < n.key {
        n.left, err = n.left.delete(key)
        return n, err
    }
    if key > n.key {
        n.right, err = n.right.delete(key)
        return n, err
    }
    // case 1: node is leaf node
    if n.isLeaf() {
        return nil, nil
    }
    // case 2a: node has left child only
    if n.hasLeft() && !n.hasRight() {
        return n.left, nil
    }
    // case 2b: node has right child only
    if n.hasRight() && !n.hasLeft() {
        return n.right, nil
    }
    // case 3: node has two children
    min := n.right.min()
    n.key = min.key
    n.val = min.val
    n.right, err = n.right.delete(min.key)
    return n, err
}

// Delete removes a key and associated data from a binary search tree
func (b *BSTree) Delete(k string) error {
    // Deleting mutates the tree, so the write lock is required here.
    b.lock.Lock()
    defer b.lock.Unlock()
    var err error
    b.root, err = b.root.delete(k)
    return err
}

func (n *node) iter(ch chan<- Item) {
    if n == nil {
        return
    }
    n.left.iter(ch)
    ch <- Item{
        Key: n.key,
        Val: n.val,
    }
    n.right.iter(ch)
}

// Iter provides an iterator to walk through the binary search tree
func (b *BSTree) Iter() <-chan Item {
    ch := make(chan Item)
    b.lock.RLock()
    go func() {
        b.root.iter(ch)
        b.lock.RUnlock()
        close(ch)
    }()
    return ch
}