Practice: Build a recursive data structure

You are given an array of floating-point numbers called numbers. All of these numbers lie between 0 and 1.

Write a function build_tree(numbers, left, right, max_in_leaf=5) that builds a "tree of bins" data structure, where

  • left is a lower bound on the values in numbers
  • right is an upper bound on the values in numbers
  • max_in_leaf is the largest number of values allowed in a leaf node of the tree

Have this function do the following:

  • If numbers contains at most max_in_leaf values, return numbers unmodified as a 'leaf node'.
  • Otherwise, return a tuple of the form (left_child, pivot, right_child), where pivot is the average of left and right. left_child is the result of passing the values of numbers that are less than pivot to build_tree, with bounds left and pivot; right_child is the same for the values greater than or equal to pivot, with bounds pivot and right.
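The split in the second bullet can be done with NumPy boolean indexing: comparing the array with pivot gives a mask, and indexing with that mask keeps only the matching values. A minimal sketch of one split step, using a few hand-picked values (illustrative only, not the exercise data):

```python
import numpy as np

# Hand-picked sample values, for illustration only
numbers = np.array([0.1, 0.7, 0.3, 0.5, 0.9, 0.2])
left, right = 0.0, 1.0

# pivot is the midpoint of the current bin
pivot = (left + right) / 2

# Boolean masks select each half; a value equal to pivot goes to the right half
below = numbers[numbers < pivot]    # values in [left, pivot)
above = numbers[numbers >= pivot]   # values in [pivot, right]

print(pivot, below, above)
```

Note that the masks preserve the original order of the values, and every value lands in exactly one half because `<` and `>=` are complementary.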

Hints:

  • look up len() to find the length of numbers, or use numbers.shape[0]
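For a one-dimensional NumPy array, the two options in the hint give the same answer. A quick check:

```python
import numpy as np

numbers = np.random.rand(100)

# For a 1-D array, len() and shape[0] both count the elements
n_len = len(numbers)
n_shape = numbers.shape[0]
print(n_len, n_shape)  # both are 100
```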
import numpy as np

numbers = np.random.rand(100)
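np.random.rand draws new values on every run, so the tree printed further down will differ each time the notebook is executed. If reproducible output is wanted, one option is NumPy's Generator API with a fixed seed. The seed 42 below is an arbitrary choice, and the resulting values will of course differ from the printed output in this notebook:

```python
import numpy as np

# A seeded Generator returns the same 100 values on every run
# (seed 42 is an arbitrary choice)
rng = np.random.default_rng(42)
numbers = rng.random(100)  # floats in the half-open interval [0, 1)

# Re-creating the generator with the same seed reproduces the array exactly
again = np.random.default_rng(42).random(100)
print(np.array_equal(numbers, again))  # True
```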
def build_tree(numbers, left, right, max_in_leaf=5):
    # ...
    pass
# Solution

def build_tree(numbers, left, right, max_in_leaf=5):
    # Leaf node: few enough values, so return them unmodified
    if len(numbers) <= max_in_leaf:
        return numbers

    # Internal node: split the bin at its midpoint and recurse on each half
    pivot = (left + right) / 2
    return (build_tree(numbers[numbers < pivot], left, pivot, max_in_leaf),
            pivot,
            build_tree(numbers[numbers >= pivot], pivot, right, max_in_leaf))
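To see the recursion on an input small enough to trace by hand, here is the same function applied to six hand-picked values (illustrative only, not part of the exercise) with max_in_leaf=2:

```python
import numpy as np

def build_tree(numbers, left, right, max_in_leaf=5):
    # Leaf node: few enough values, so return them unmodified
    if len(numbers) <= max_in_leaf:
        return numbers
    # Internal node: split the bin at its midpoint and recurse on each half
    pivot = (left + right) / 2
    return (build_tree(numbers[numbers < pivot], left, pivot, max_in_leaf),
            pivot,
            build_tree(numbers[numbers >= pivot], pivot, right, max_in_leaf))

# Small hand-picked input, for illustration only
small = np.array([0.1, 0.2, 0.3, 0.6, 0.7, 0.8])
tree = build_tree(small, 0, 1, max_in_leaf=2)
print(tree)
```

Tracing it: the root splits at 0.5. The lower half [0.1, 0.2, 0.3] still has three values, so it splits again at 0.25 into the leaves [0.1, 0.2] and [0.3]. The upper half [0.6, 0.7, 0.8] splits at 0.75 into [0.6, 0.7] and [0.8].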
tree = build_tree(numbers, 0, 1)
print(tree)
((((array([ 0.03155442,  0.04969038,  0.00203516,  0.01134467]), 0.0625, array([ 0.08795129,  0.08484712,  0.10400076])), 0.125, ((array([ 0.13288577,  0.1348917 ,  0.13717107,  0.13363111]), 0.15625, array([ 0.18361257,  0.16379185,  0.17935313])), 0.1875, array([ 0.18835925]))), 0.25, ((((array([ 0.25471829,  0.25316833,  0.25368532]), 0.265625, array([ 0.2753575 ,  0.273414  ,  0.27016936])), 0.28125, array([ 0.31082018,  0.29369432,  0.29940896,  0.30908776])), 0.3125, array([ 0.33571279,  0.37308478,  0.33152007,  0.35286179])), 0.375, (array([ 0.42385846,  0.4181284 ,  0.41651459,  0.40505667,  0.39770273]), 0.4375, (array([ 0.45339789,  0.45886606,  0.45242226,  0.46320172]), 0.46875, array([ 0.47478645,  0.47411047]))))), 0.5, (((((array([ 0.50499705,  0.50985016]), 0.515625, array([ 0.52229993,  0.51933766,  0.51886349,  0.52999165,  0.52412507])), 0.53125, array([ 0.55858432,  0.54768638,  0.55832187,  0.5325089 ,  0.55220628])), 0.5625, (array([ 0.57673859,  0.56823789,  0.5813294 ,  0.58937822]), 0.59375, array([ 0.6160641 ,  0.61934406,  0.602265  ,  0.6162048 ]))), 0.625, (array([ 0.65624928,  0.6413943 ,  0.67663038,  0.64642908]), 0.6875, (array([ 0.70809202,  0.68992577,  0.70079716,  0.70850778]), 0.71875, array([ 0.74158749,  0.73977694,  0.73835865])))), 0.75, (((array([ 0.77388481,  0.75894454,  0.75947687,  0.75107273,  0.75294742]), 0.78125, array([ 0.81109003,  0.79674588,  0.78627348,  0.80242775,  0.79621333])), 0.8125, array([ 0.85778992,  0.8153317 ,  0.8164691 ,  0.84316018])), 0.875, ((array([ 0.88278667,  0.88526777,  0.89802129]), 0.90625, array([ 0.92812108,  0.90822701,  0.91261268])), 0.9375, ((array([ 0.94584389,  0.95214282,  0.93878381,  0.94956855]), 0.953125, array([ 0.96443695,  0.96785209])), 0.96875, array([ 0.9805466 ,  0.97100946,  0.99383466]))))))
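The printed tree is hard to verify by eye. The sketch below adds check_tree, a hypothetical helper that is not part of the exercise. It walks the tree and asserts the invariants from the specification: every pivot is the midpoint of its bin, and every leaf holds at most max_in_leaf values that lie inside its bin. It returns the total number of stored values, so you can confirm that nothing was lost or duplicated. It assumes, as in the solution, that leaves are NumPy arrays and internal nodes are tuples, and it uses a seeded generator instead of np.random.rand so the check is reproducible:

```python
import numpy as np

def build_tree(numbers, left, right, max_in_leaf=5):
    # Same recursive construction as the solution above
    if len(numbers) <= max_in_leaf:
        return numbers
    pivot = (left + right) / 2
    return (build_tree(numbers[numbers < pivot], left, pivot, max_in_leaf),
            pivot,
            build_tree(numbers[numbers >= pivot], pivot, right, max_in_leaf))

def check_tree(tree, left, right, max_in_leaf=5):
    """Hypothetical helper: assert the tree-of-bins invariants and
    return the total number of values stored in the leaves."""
    if isinstance(tree, tuple):
        left_child, pivot, right_child = tree
        # Each pivot must be the midpoint of its bin
        assert pivot == (left + right) / 2
        return (check_tree(left_child, left, pivot, max_in_leaf)
                + check_tree(right_child, pivot, right, max_in_leaf))
    # Leaf: small enough, and every value inside the bin [left, right]
    assert len(tree) <= max_in_leaf
    assert np.all((tree >= left) & (tree <= right))
    return len(tree)

rng = np.random.default_rng(0)
numbers = rng.random(100)
tree = build_tree(numbers, 0, 1)
total = check_tree(tree, 0, 1)
print(total)  # 100: every value ends up in exactly one leaf
```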