A Concurrency-safe, Iterable Binary Search Tree
In today’s article, we will implement a simple, unbalanced, concurrency-safe binary search tree in Golang. We will use some nice patterns along the way that are worth knowing.
I assume that the reader has a basic understanding of the Go programming language and knows what a binary search tree is. What I want to show is how easy it is to create a concurrency-safe version of this basic data structure in Golang.
Let’s begin! The API for our binary search tree should offer the following actions:
- Value(key) returns a node’s associated data (the value)
- Upsert(key, value) updates or inserts the value for a node identified by key
- Iter() iterates over all nodes of the tree in order
The key is a string that is unique within an instance of the tree; it identifies an individual node. Regarding the stored value, we will be more agnostic and use an empty interface (interface{}), thus allowing all kinds of data to be stored in our tree. From the data structure's point of view, we do not care about the value, so why should we care about its type?
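Two consequences of these choices are worth seeing in code. Because keys are plain strings, the tree orders them with Go's built-in `<` operator, which compares strings byte by byte (lexicographically), not numerically. And because values are stored as `interface{}`, a caller has to use a type assertion or type switch to get the concrete type back. The sketch below illustrates both; `describe` is a hypothetical helper, not part of the tree's API.

```go
package main

import "fmt"

// describe is a hypothetical helper showing how a caller might recover the
// concrete type of a value stored as interface{} using a type switch.
func describe(v interface{}) string {
	switch x := v.(type) {
	case int:
		return fmt.Sprintf("int %d", x)
	case string:
		return fmt.Sprintf("string %q", x)
	default:
		return fmt.Sprintf("other %T", x)
	}
}

func main() {
	// String keys compare byte by byte: "10" sorts before "9",
	// and uppercase letters sort before lowercase ones.
	fmt.Println("10" < "9") // true
	fmt.Println("Z" < "a")  // true

	// Any value fits into interface{}; its type is recovered on the way out.
	for _, v := range []interface{}{42, "hello", 3.14} {
		fmt.Println(describe(v))
	}
}
```

If numeric keys should sort numerically, one option is to zero-pad them (for example "007" instead of "7") before using them as keys.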
The Tree
For now, the tree is represented by a simple type that holds just the root node:
type BSTree struct {
	root *node
}
The Nodes
For the nodes we will be using a simple struct with fields for the key, value, and pointers to child nodes:
type node struct {
	key   string
	val   interface{}
	left  *node
	right *node
}
Retrieving the value of a node is, no surprise here, a recursive function. It leverages the fact that a binary search tree is, by definition, sorted. If the key we are searching for is smaller than the key of the node we are looking at, we proceed to the left child. If the key is larger, we proceed to the right child. If neither applies, we have found the node we were looking for. Calling a method on a nil pointer is legal in Go, so when the search runs past a leaf, the nil check at the top of the method returns ErrorNotFound (a sentinel error defined in the full source at the end of this article).
func (n *node) value(key string) (interface{}, error) {
	if n == nil {
		return nil, ErrorNotFound
	}
	if key < n.key {
		return n.left.value(key)
	} else if key > n.key {
		return n.right.value(key)
	}
	return n.val, nil
}
Upsert
Upsert, the union of updating and inserting, works very similarly. The only difference is that if we cannot find the node, we create it. Note that upsert() expects a non-nil receiver; the case of an empty tree is handled by the public wrapper.
func (n *node) upsert(key string, val interface{}) {
	if key < n.key {
		if n.left == nil {
			n.left = &node{key: key, val: val}
		} else {
			n.left.upsert(key, val)
		}
	} else if key > n.key {
		if n.right == nil {
			n.right = &node{key: key, val: val}
		} else {
			n.right.upsert(key, val)
		}
	} else {
		n.val = val
	}
}
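Because the tree is unbalanced, inserting keys in sorted order produces a chain in which every node has only a right child, and the recursion in upsert() then goes as deep as the tree has nodes. As a hedged alternative (a swapped-in technique, not the article's code), the same logic can be written as a loop that walks down through a pointer to the child slot, which also handles an empty tree without a special case. The names `upsertIter` and `height` are hypothetical.

```go
package main

import "fmt"

type node struct {
	key   string
	val   interface{}
	left  *node
	right *node
}

// upsertIter is a hypothetical iterative variant of upsert. It walks down
// using a pointer to the child slot, so inserting into an empty tree
// (*root == nil) needs no special case and recursion depth is not a concern.
func upsertIter(root **node, key string, val interface{}) {
	slot := root
	for *slot != nil {
		n := *slot
		switch {
		case key < n.key:
			slot = &n.left
		case key > n.key:
			slot = &n.right
		default:
			n.val = val // key exists: update in place
			return
		}
	}
	*slot = &node{key: key, val: val} // key missing: attach a new leaf
}

// height returns the number of nodes on the longest root-to-leaf path.
func height(n *node) int {
	if n == nil {
		return 0
	}
	l, r := height(n.left), height(n.right)
	if l > r {
		return l + 1
	}
	return r + 1
}

func main() {
	var root *node
	for _, k := range []string{"a", "b", "c", "d", "e"} {
		upsertIter(&root, k, k)
	}
	// Sorted input degenerates into a right-leaning chain.
	fmt.Println(height(root)) // 5

	upsertIter(&root, "c", "updated")
	fmt.Println(root.right.right.val) // updated
}
```

Inserting the same keys in a mixed order (for example "c", "a", "d", "b", "e") yields a shallower tree, which is why balanced variants such as red-black or AVL trees exist.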
Concurrency-safe API using a Read-Write-Mutex
We will take two steps to ensure a concurrency-safe API:
- Protect the tree from concurrent modification by introducing a read-write mutex (a lock): any number of readers may hold it at the same time, but a writer needs exclusive access.
- Create public wrapper methods that honor the mutex around the private methods. For example, the method Upsert() is concurrency-safe while the method upsert() is not, so Upsert() has to take care of locking and unlocking.
It is considered good style to place the synchronizing entity directly above the field that it is protecting:
type BSTree struct {
	lock sync.RWMutex // <-- This one is protecting 'root'
	root *node
}
Here the sync.RWMutex protects the root node and all its descendants.
The public method Upsert() of the BSTree type acquires the lock and then either creates the root node (if the tree is still empty) or calls the private method upsert() on it. Releasing the lock is deferred for convenience.
func (b *BSTree) Upsert(key string, val interface{}) {
	b.lock.Lock()
	defer b.lock.Unlock()
	// if the tree is empty, the new node becomes the root
	if b.root == nil {
		b.root = &node{key: key, val: val}
	} else {
		b.root.upsert(key, val)
	}
}
Using this simple pattern, we keep the methods that actually change data free of synchronization artifacts, while all the locking and unlocking is concentrated in the public methods. I like this pattern because it gives me a quick overview of the entry points into my library code.
Iterating using a Channel
Iterating over a binary search tree in order is relatively easy. We recursively walk the tree and emit the left subtree first, then the node itself, and finally the right subtree; because the tree is sorted, the keys come out in ascending order. Assuming we want to send the items into a channel, the code might look like this:
func (n *node) iter(ch chan<- Item) {
	if n == nil {
		return
	}
	n.left.iter(ch)
	ch <- Item{
		Key: n.key,
		Val: n.val,
	}
	n.right.iter(ch)
}
The channel carries values of type Item, a simple struct that holds the key and the value of a node. It looks like this:
type Item struct {
	Key string
	Val interface{}
}
Now comes the tricky part. We want external code to iterate over the tree in a concurrency-safe way. This means we can allow one or more readers, each with its own channel for Item retrieval, but no writer until the last reader has finished. The read-write synchronization is already taken care of by the mutex. Nice! However, parallel, non-blocking reading through multiple channels is not that easy. It calls for a smart use of channels and goroutines.
Let’s start simple: a reader calling Iter() receives a receive-only channel (<-chan Item). Obviously, we have to create this channel at the beginning of the method and return it when the method ends. This is easy:
func (b *BSTree) Iter() <-chan Item {
	ch := make(chan Item)
	b.lock.RLock()
	// here be dragons
	return ch
}
But wait, don’t we have to close the channel eventually? And how do we keep the program from blocking, given that the reader cannot receive from the channel before we have returned it? Furthermore, the channel is unbuffered, so every send blocks until the reader receives the value; we can only keep sending as long as the reader keeps receiving. And when do we unlock the mutex?
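The blocking problem comes from the channel being unbuffered: a send completes only when a receiver takes the value at the same moment. A `select` with a `default` branch makes this visible without deadlocking the program. This is purely a demonstration and not part of the tree; `trySend` is a hypothetical helper name.

```go
package main

import "fmt"

// trySend attempts a non-blocking send and reports whether it succeeded.
func trySend(ch chan<- int, v int) bool {
	select {
	case ch <- v:
		return true
	default:
		return false // the send would have blocked
	}
}

func main() {
	unbuffered := make(chan int)
	fmt.Println(trySend(unbuffered, 1)) // false: nobody is receiving

	buffered := make(chan int, 1)
	fmt.Println(trySend(buffered, 1)) // true: fits in the buffer
	fmt.Println(trySend(buffered, 2)) // false: buffer is full

	// With a receiver running in another goroutine, a plain send completes.
	done := make(chan int)
	go func() { done <- <-unbuffered }()
	unbuffered <- 7
	fmt.Println(<-done) // 7
}
```

The last three lines are the same idea Iter() relies on: the sending side and the receiving side live in different goroutines.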
The solution is to delegate all these tasks to a goroutine. The goroutine runs a closure, so it has access to the variables of the enclosing method, namely b and ch. This allows us to call the iter() method on the root node, unlock the mutex, and close the channel (which ends the reader's range loop), all inside the goroutine. Suddenly, everything is taken care of and we can return the channel we just created.
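func (b *BSTree) Iter() <-chan Item {
	ch := make(chan Item)
	b.lock.RLock() // held until the walk has finished
	go func() {
		b.root.iter(ch) // each send blocks until the reader receives
		b.lock.RUnlock()
		close(ch) // ends the reader's range loop
	}()
	return ch
}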
It is OK to be confused by the code at first. A lot is happening here and the pattern might not be clear from the beginning. Feel free to take your time to grasp it.
Conclusion
Implementing a binary search tree in Golang is straightforward and does not require deep language knowledge. By adding synchronizing tools like mutexes and channels, data structures can be made ready for showtime in concurrent programs.
Source
package bstree

import (
	"fmt"
	"sync"
)
// Item holds the key and value of a node to be returned by an iterator
type Item struct {
	Key string
	Val interface{}
}
type node struct {
	key   string
	val   interface{}
	left  *node
	right *node
}
// BSTree represents a binary search tree
type BSTree struct {
	lock sync.RWMutex // protects root and all its descendants
	root *node
}
// ErrorNotFound is returned when a key is not in the binary search tree
var ErrorNotFound = fmt.Errorf("not found")
func (n *node) value(key string) (interface{}, error) {
	if n == nil {
		return nil, ErrorNotFound
	}
	if key < n.key {
		return n.left.value(key)
	} else if key > n.key {
		return n.right.value(key)
	}
	return n.val, nil
}
// Value returns the data associated with a given key
func (b *BSTree) Value(key string) (interface{}, error) {
	b.lock.RLock()
	defer b.lock.RUnlock()
	return b.root.value(key)
}
func (n *node) upsert(key string, val interface{}) {
	if key < n.key {
		if n.left == nil {
			n.left = &node{key: key, val: val}
		} else {
			n.left.upsert(key, val)
		}
	} else if key > n.key {
		if n.right == nil {
			n.right = &node{key: key, val: val}
		} else {
			n.right.upsert(key, val)
		}
	} else {
		n.val = val
	}
}
// Upsert updates or inserts data associated to a given key
func (b *BSTree) Upsert(key string, val interface{}) {
	b.lock.Lock()
	defer b.lock.Unlock()
	// if root node is empty, new node is root now
	if b.root == nil {
		b.root = &node{key: key, val: val}
	} else {
		b.root.upsert(key, val)
	}
}
func (n *node) isLeaf() bool {
	return !n.hasLeft() && !n.hasRight()
}

func (n *node) hasLeft() bool {
	return n.left != nil
}

func (n *node) hasRight() bool {
	return n.right != nil
}
// min returns the node with the smallest key in the subtree rooted at n
func (n *node) min() *node {
	for n.left != nil {
		n = n.left
	}
	return n
}
func (n *node) delete(key string) (*node, error) {
	var err error
	if n == nil {
		return nil, ErrorNotFound
	}
	if key < n.key {
		n.left, err = n.left.delete(key)
		return n, err
	}
	if key > n.key {
		n.right, err = n.right.delete(key)
		return n, err
	}
	// case 1: node is leaf node
	if n.isLeaf() {
		return nil, nil
	}
	// case 2a: node has left child only
	if n.hasLeft() && !n.hasRight() {
		return n.left, nil
	}
	// case 2b: node has right child only
	if n.hasRight() && !n.hasLeft() {
		return n.right, nil
	}
	// case 3: node has two children
	min := n.right.min()
	n.key = min.key
	n.val = min.val
	n.right, err = n.right.delete(min.key)
	return n, err
}
// Delete removes a key and its associated data from the binary search tree
func (b *BSTree) Delete(k string) error {
	var err error
	// Delete modifies the tree, so it needs the exclusive write lock
	b.lock.Lock()
	defer b.lock.Unlock()
	b.root, err = b.root.delete(k)
	return err
}
func (n *node) iter(ch chan<- Item) {
	if n == nil {
		return
	}
	n.left.iter(ch)
	ch <- Item{
		Key: n.key,
		Val: n.val,
	}
	n.right.iter(ch)
}
// Iter provides an iterator to walk through the binary search tree
func (b *BSTree) Iter() <-chan Item {
	ch := make(chan Item)
	b.lock.RLock()
	go func() {
		b.root.iter(ch)
		b.lock.RUnlock()
		close(ch)
	}()
	return ch
}