I'm trying to learn F# by solving some programming puzzles. I don't want to give too many details about the problem, as I don't want to spoil the fun for others.
Basically, the task is to find all 4-tuples { (i,j,k,l) | i ^ j ^ k ^ l != 0 } with no repetition (e.g., (1,1,1,2) and (1,1,2,1) are the same and should be counted just once).
I have found an O(n^3) approach that works; please see countImperative(a,b,c,d) below. I also tried to refactor the code to get rid of the nested for loops, but I could not do so without a significant performance penalty. It was my impression that F#'s syntactic sugar would allow a more concise style (using pipes and folds), letting the compiler do the heavy lifting to produce code about as fast as my nested for loops. The big performance hit comes from the calculation of the partial2 sum.
Here's the code:
open System
open System.Diagnostics
open System.Collections

module quadruples =
    [<EntryPoint>]
    let main argv =
        let input = "2000 2000 2000 2000"
        let ordered = [ for x in input.Split([|' '|]) -> Convert.ToInt32(x) ] |> List.sort
        let a,b,c,d = ordered.[0], ordered.[1], ordered.[2], ordered.[3]
        let inner(a,b) = a * (a-1) / 2 + a * (b-a)
        let sw = new Stopwatch()
        sw.Start()
        let partial1 = [ 1.. b ] |> List.fold (fun acc j -> acc + (int64 ((min a j) * inner(c-j+1, d-j+1)))) 0L
        sw.Stop()
        let elapsed1 = (sw.ElapsedMilliseconds |> double) / 1000.0
        printfn "Partial1: %f s" elapsed1
        sw.Restart()
        let combinations = [ for i in 1..a do for j in i+1..b do yield (j,i^^^j) ]
        let range = [ 1..c ]
        let partial2 = combinations |> List.fold(fun acc (j,x) -> acc + (range |> List.skip(j-1) |> List.fold(fun acc k -> if k ^^^ x < k || k ^^^ x > d then acc + 1L else acc) 0L)) 0L
        sw.Stop()
        let elapsed2 = (sw.ElapsedMilliseconds |> double) / 1000.0
        printfn "Partial2: %f s" elapsed2
        printfn "Functional: %d, Elapsed: %f s" (partial1 + partial2) (elapsed1 + elapsed2)
        // "imperative" approach
        let countImperative(a,b,c,d) =
            let mutable count = seq { 1..b } |> Seq.fold (fun acc j -> acc + (int64 ((min a j) * inner(c-j+1, d-j+1)))) 0L
            for i in 1..a do
                for j in i+1..b do
                    let x = i ^^^ j
                    for k in j..c do
                        let y = x ^^^ k
                        if y < k || y > d then
                            count <- count + 1L
            count
        sw.Restart()
        let count = countImperative(a,b,c,d)
        sw.Stop()
        printfn "Imperative: %d, Elapsed: %f s" count ((sw.ElapsedMilliseconds |> double) / 1000.0)
        0 // return an integer exit code
So my question is whether there is any way to speed up the code (specifically the calculation of partial2) while keeping F#'s nice syntax.
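One possible direction, offered as a sketch rather than a measured answer: part of the cost of partial2 probably comes from materialising the combinations list and from List.skip(j-1), which walks the list again for every (i, j) pair. The version below keeps a functional shape but never builds those lists. The name partial2Fast and the helper countK are made up for this illustration, and no timing claim is implied; it would need to be benchmarked against countImperative.

```fsharp
// Hypothetical replacement for the partial2 computation (names are illustrative).
// It counts the same (i, j, k) triples as the nested loops in countImperative.
let partial2Fast (a: int, b: int, c: int, d: int) : int64 =
    // Innermost loop written as a tail-recursive function over plain ints,
    // so no list cells or tuples are allocated per k.
    let rec countK x k acc =
        if k > c then acc
        else
            let y = x ^^^ k
            countK x (k + 1) (if y < k || y > d then acc + 1L else acc)
    // Only the outer O(n^2) part goes through a lazy sequence.
    seq {
        for i in 1 .. a do
            for j in i + 1 .. b do
                yield countK (i ^^^ j) j 0L
    }
    |> Seq.sum
```

Adding partial1 to the result of partial2Fast (a, b, c, d) should give the same total as the "Functional" line in the program above.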
First, for full disclosure, I want to point out that this is related to homework in a Machine Learning class. This question is not the homework question itself; it is something I need to figure out in order to complete the bigger problem of creating an ID3 decision tree algorithm.
I need to generate a tree similar to the following when given a truth table:
let learnedTree = Node(0,"A0", Node(2,"A2", Leaf(0), Leaf(1)), Node(1,"A1", Node(2,"A2", Leaf(0), Leaf(1)), Leaf(0)))
learnedTree is of type BinaryTree which I've defined as follows:
type BinaryTree =
| Leaf of int
| Node of int * string * BinaryTree * BinaryTree
ID3 algorithms use various equations to determine where to split the tree, and I've got all that figured out; I'm just having trouble creating the learned tree from my truth table. For example, if I have the following table:
A1 | A2 | A3 | Class
 1 |  0 |  0 | 1
 0 |  1 |  0 | 1
 0 |  0 |  0 | 0
 1 |  0 |  1 | 0
 0 |  0 |  0 | 0
 1 |  1 |  0 | 1
 0 |  1 |  1 | 0
And I decide to split on attribute A1, I would end up with the following:
     (A1 = 1)          A1          (A1 = 0)
A2 | A3 | Class              A2 | A3 | Class
 0 |  0 | 1                   1 |  0 | 1
 0 |  1 | 0                   0 |  0 | 0
 1 |  0 | 1                   0 |  0 | 0
                              1 |  1 | 0
Then I would split the left side and split the right side, and continue the recursive pattern until the leaf nodes are pure and I end up with a tree similar to the following based on the splitting.
let learnedTree = Node(0,"A0", Node(2,"A2", Leaf(0), Leaf(1)), Node(1,"A1", Node(2,"A2", Leaf(0), Leaf(1)), Leaf(0)))
Here is what I've kind of "hacked" together thus far, but I think I might be way off:
let rec createTree (listToSplit : list<list<float>>) index =
    let leftSideSplit =
        listToSplit |> List.choose (fun x -> if x.Item(index) = 1. then Some(x) else None)
    let rightSideSplit =
        listToSplit |> List.choose (fun x -> if x.Item(index) = 0. then Some(x) else None)
    if leftSideSplit.Length > 0 then
        let pureCheck = isListPure leftSideSplit
        if pureCheck = 0 then
            printfn "%s" "Pure left node class 0"
            createTree leftSideSplit (index + 1)
        else if pureCheck = 1 then
            printfn "%s" "Pure left node class 1"
            createTree leftSideSplit (index + 1)
        else
            printfn "%s - %A" "Recursing Left" leftSideSplit
            createTree leftSideSplit (index + 1)
    else printfn "%s" "Pure left node class 0"
Should I be using pattern matching instead? Any tips/ideas/help? Thanks a bunch!
Edit: I've since posted an implementation of ID3 on my blog at:
http://blogs.msdn.com/chrsmith
Hey Jim, I've been wanting to write a blog post implementing ID3 in F# for a while - thanks for giving me an excuse. While this code doesn't implement the algorithm fully (or correctly), it should be sufficient to get you started.
In general you have the right approach - representing each branch as a discriminated union case is good. And like Brian said, List.partition is definitely a handy function. The trick to making this work correctly is all in determining the optimal attribute/value pair to split on - and to do that you'll need to calculate information gain via entropy, etc.
type Attribute = string
type Value = string

type Record =
    {
        Weather : string
        Temperature : string
        PlayTennis : bool
    }
    override this.ToString() =
        sprintf
            "{Weather = %s, Temp = %s, PlayTennis = %b}"
            this.Weather
            this.Temperature
            this.PlayTennis

type Decision = Attribute * Value

type DecisionTreeNode =
    | Branch of Decision * DecisionTreeNode * DecisionTreeNode
    | Leaf of Record list

// ------------------------------------
// Splits a record list into an optimal split and the left / right branches.
// (This is where you use the entropy function to maximize information gain.)
// Record list -> Decision * Record list * Record list
let bestSplit data =
    // Just group by weather, then by temperature
    let uniqueWeathers =
        List.fold
            (fun acc item -> Set.add item.Weather acc)
            Set.empty
            data
    let uniqueTemperatures =
        List.fold
            (fun acc item -> Set.add item.Temperature acc)
            Set.empty
            data
    if uniqueWeathers.Count = 1 then
        let bestSplit = ("Temperature", uniqueTemperatures.MinimumElement)
        let left, right =
            List.partition
                (fun item -> item.Temperature = uniqueTemperatures.MinimumElement)
                data
        (bestSplit, left, right)
    else
        let bestSplit = ("Weather", uniqueWeathers.MinimumElement)
        let left, right =
            List.partition
                (fun item -> item.Weather = uniqueWeathers.MinimumElement)
                data
        (bestSplit, left, right)

let rec determineBranch data =
    if List.length data < 4 then
        Leaf(data)
    else
        // Use the entropy function to break the dataset on
        // the category / value that best splits the data
        let bestDecision, leftBranch, rightBranch = bestSplit data
        Branch(
            bestDecision,
            determineBranch leftBranch,
            determineBranch rightBranch)

// ------------------------------------
let rec printID3Result indent branch =
    let padding = new System.String(' ', indent)
    match branch with
    | Leaf(data) ->
        data |> List.iter (fun item -> printfn "%s%s" padding <| item.ToString())
    | Branch(decision, lhs, rhs) ->
        printfn "%sBranch predicate [%A]" padding decision
        printfn "%sWhere predicate is true:" padding
        printID3Result (indent + 4) lhs
        printfn "%sWhere predicate is false:" padding
        printID3Result (indent + 4) rhs

// ------------------------------------
let dataset =
    [
        { Weather = "windy"; Temperature = "hot"; PlayTennis = false }
        { Weather = "windy"; Temperature = "cool"; PlayTennis = false }
        { Weather = "nice"; Temperature = "cool"; PlayTennis = true }
        { Weather = "nice"; Temperature = "cold"; PlayTennis = true }
        { Weather = "humid"; Temperature = "hot"; PlayTennis = false }
    ]

printfn "Given input list:"
dataset |> List.iter (printfn "%A")
printfn "ID3 split resulted in:"
let id3Result = determineBranch dataset
printID3Result 0 id3Result
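The bestSplit above is a placeholder that ignores entropy. As a rough sketch of the missing piece, the two functions below compute Shannon entropy over boolean labels and the information gain of one candidate split. The names entropy and informationGain, and the choice to pass the label and the split as plain functions, are assumptions made for this example, not part of the code above.

```fsharp
// Shannon entropy, in bits, of a list of boolean class labels.
// An empty list yields 0.0 because there is nothing to sum.
let entropy (labels: bool list) : float =
    let n = float (List.length labels)
    labels
    |> List.countBy id
    |> List.sumBy (fun (_, count) ->
        let p = float count / n
        -p * System.Math.Log(p, 2.0))

// Information gain of splitting `rows` with `predicate`.
// `label` extracts the class of a row (for example, fun r -> r.PlayTennis).
let informationGain (label: 'a -> bool) (predicate: 'a -> bool) (rows: 'a list) : float =
    let yes, no = List.partition predicate rows
    let total = float (List.length rows)
    // Entropy of one side, weighted by the share of rows that landed there.
    let weighted (part: 'a list) =
        if List.isEmpty part then 0.0
        else float (List.length part) / total * entropy (List.map label part)
    entropy (List.map label rows) - weighted yes - weighted no
```

A real bestSplit would evaluate informationGain for every candidate attribute/value pair and keep the pair with the largest gain.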
You can use List.partition instead of your two List.choose calls.
http://research.microsoft.com/en-us/um/cambridge/projects/fsharp/manual/FSharp.Core/Microsoft.FSharp.Collections.List.html
(or now http://msdn.microsoft.com/en-us/library/ee353738(VS.100).aspx )
It isn't clear to me that pattern matching will buy you much here; the input type (list of lists) and the processing (partitioning and a 'pureness' check) don't really lend themselves to that.
And of course when you finally get to the 'end' (a pure list) you need to create a tree, so presumably this function will create a Leaf when the input only has one 'side' and it's 'pure', but create a Node out of the left-side and right-side results for every other input. Maybe. I didn't quite grok the algorithm completely.
Hopefully that will help steer you a little bit. May be useful to draw up a few smaller sample inputs and outputs to help work out the various cases of the function body.
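To make the Leaf/Node idea concrete, here is a minimal sketch that returns a BinaryTree instead of printing. It is not ID3: it simply splits on the attributes in the order given, where a real implementation would pick the attribute with the highest information gain. The names classOf and buildTree, the "left side = attribute is 1" convention, the default of Leaf 0 for an empty side, and the majority vote when attributes run out are all assumptions made for this illustration.

```fsharp
type BinaryTree =
    | Leaf of int
    | Node of int * string * BinaryTree * BinaryTree

// Assumption: the class label is the last column of each row.
let classOf (row: float list) : int = int (List.last row)

// Builds a tree by splitting on the attribute indices in `attrs`, in order.
let rec buildTree (attrs: int list) (rows: float list list) : BinaryTree =
    let classes = rows |> List.map classOf |> List.distinct
    match classes, attrs with
    | [], _ -> Leaf 0                 // assumption: an empty side defaults to class 0
    | [c], _ -> Leaf c                // pure: every row has the same class
    | _, [] ->
        // Out of attributes but still impure: fall back to a majority vote.
        let ones = rows |> List.filter (fun r -> classOf r = 1) |> List.length
        Leaf (if 2 * ones >= List.length rows then 1 else 0)
    | _, i :: rest ->
        // List.partition produces both sides in a single pass.
        let yes, no = rows |> List.partition (fun r -> List.item i r = 1.0)
        Node(i, sprintf "A%d" i, buildTree rest yes, buildTree rest no)
```

Calling buildTree [0; 1; 2] on the rows of the truth table in the question yields a nested Node/Leaf value of the same shape as learnedTree.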
Thanks Brian & Chris! I was actually able to figure this out and I ended up with the following. This calculates the information gain for determining the best place to split. I'm sure there are probably better ways for me to arrive at this solution especially around the chosen data structures, but this is a start. I plan to refine things later.
#light
open System

let trainList =
    [
        [1.;0.;0.;1.;];
        [0.;1.;0.;1.;];
        [0.;0.;0.;0.;];
        [1.;0.;1.;0.;];
        [0.;0.;0.;0.;];
        [1.;1.;0.;1.;];
        [0.;1.;1.;0.;];
        [1.;0.;0.;1.;];
        [0.;0.;0.;0.;];
        [1.;0.;0.;1.;];
    ]

type BinaryTree =
    | Leaf of int
    | Node of int * string * BinaryTree * BinaryTree

let entropyList nums =
    let sumOfnums =
        nums
        |> Seq.sum
    nums
    |> Seq.map (fun x -> if x=0.00 then x else (-((x/sumOfnums) * Math.Log(x/sumOfnums, 2.))))
    |> Seq.sum

let entropyBinaryList (dataListOfLists:list<list<float>>) =
    let classList =
        dataListOfLists
        |> List.map (fun x -> x.Item(x.Length - 1))
    let ListOfNo =
        classList
        |> List.choose (fun x -> if x = 0. then Some(x) else None)
    let ListOfYes =
        classList
        |> List.choose (fun x -> if x = 1. then Some(x) else None)
    let numberOfYes : float = float ListOfYes.Length
    let numberOfNo : float = float ListOfNo.Length
    let ListOfNumYesAndSumNo = [numberOfYes; numberOfNo]
    entropyList ListOfNumYesAndSumNo

let conditionalEntropy (dataListOfLists:list<list<float>>) attributeNumber =
    let NoAttributeList =
        dataListOfLists
        |> List.choose (fun x -> if x.Item(attributeNumber) = 0. then Some(x) else None)
    let YesAttributeList =
        dataListOfLists
        |> List.choose (fun x -> if x.Item(attributeNumber) = 1. then Some(x) else None)
    let numberOfYes : float = float YesAttributeList.Length
    let numberOfNo : float = float NoAttributeList.Length
    let noConditionalEntropy = (entropyBinaryList NoAttributeList) * (numberOfNo/(numberOfNo + numberOfYes))
    let yesConditionalEntropy = (entropyBinaryList YesAttributeList) * (numberOfYes/(numberOfNo + numberOfYes))
    [noConditionalEntropy; yesConditionalEntropy]

let findBestSplitIndex(listOfInstances : list<list<float>>) =
    let IGList =
        [0..(listOfInstances.Item(0).Length - 2)]
        |> List.mapi (fun i x -> (i, (entropyBinaryList listOfInstances) - (List.sum (conditionalEntropy listOfInstances x))))
    IGList
    |> List.maxBy snd
    |> fst

let isListPure (listToCheck : list<list<float>>) =
    let splitList = listToCheck |> List.choose (fun x -> if x.Item(x.Length - 1) = 1. then Some(x) else None)
    if splitList.Length = listToCheck.Length then 1
    else if splitList.Length = 0 then 0
    else -1

let rec createTree (listToSplit : list<list<float>>) =
    let pureCheck = isListPure listToSplit
    if pureCheck = 0 then
        printfn "%s" "Pure - Leaf(0)"
    else if pureCheck = 1 then
        printfn "%s" "Pure - Leaf(1)"
    else
        printfn "%A - is not pure" listToSplit
        if listToSplit.Length > 1 then // There are attributes we can split on
            // Choose best place to split list
            let splitIndex = findBestSplitIndex(listToSplit)
            printfn "splitting at index %A" splitIndex
            let leftSideSplit =
                listToSplit |> List.choose (fun x -> if x.Item(splitIndex) = 1. then Some(x) else None)
            let rightSideSplit =
                listToSplit |> List.choose (fun x -> if x.Item(splitIndex) = 0. then Some(x) else None)
            createTree leftSideSplit
            createTree rightSideSplit
        else
            printfn "%s" "Not Pure, but can't split choose based on heuristics - Leaf(0 or 1)"
How would you convert type Node into an immutable tree?
This class implements a range tree that does not allow overlapping or adjacent ranges and instead joins them. For example, if the root node is {min = 10; max = 20}, then its right child and all of that child's descendants must have min and max values greater than 21. The max value of a range must be greater than or equal to the min. I included a test function so you can run this as is, and it will dump out any cases that fail.
I recommend starting with the Insert method to read this code.
module StackOverflowQuestion

open System

type Range =
    { min : int64; max : int64 }
    with
        override this.ToString() =
            sprintf "(%d, %d)" this.min this.max

[<AllowNullLiteralAttribute>]
type Node(left:Node, right:Node, range:Range) =
    let mutable left = left
    let mutable right = right
    let mutable range = range

    // Symmetric to clean right
    let rec cleanLeft(node : Node) =
        if node.Left = null then
            ()
        elif range.max < node.Left.Range.min - 1L then
            cleanLeft(node.Left)
        elif range.max <= node.Left.Range.max then
            range <- {min = range.min; max = node.Left.Range.max}
            node.Left <- node.Left.Right
        else
            node.Left <- node.Left.Right
            cleanLeft(node)

    // Clean right deals with merging when the node to merge with is not on the
    // left outside of the tree. It travels right inside the tree looking for an
    // overlapping node. If it finds one it merges the range and replaces the
    // node with its left child thereby deleting it. If it finds a subset node
    // it replaces it with its left child, checks it and continues looking right.
    let rec cleanRight(node : Node) =
        if node.Right = null then
            ()
        elif range.min > node.Right.Range.max + 1L then
            cleanRight(node.Right)
        elif range.min >= node.Right.Range.min then
            range <- {min = node.Right.Range.min; max = range.max}
            node.Right <- node.Right.Left
        else
            node.Right <- node.Right.Left
            cleanRight(node)

    // Merge left is called whenever the min value of a node decreases.
    // It handles the case of left node overlap/subsets and merging/deleting them.
    // When no more overlaps are found on the left nodes it calls clean right.
    let rec mergeLeft(node : Node) =
        if node.Left = null then
            ()
        elif range.min <= node.Left.Range.min - 1L then
            node.Left <- node.Left.Left
            mergeLeft(node)
        elif range.min <= node.Left.Range.max + 1L then
            range <- {min = node.Left.Range.min; max = range.max}
            node.Left <- node.Left.Left
        else
            cleanRight(node.Left)

    // Symmetric to merge left
    let rec mergeRight(node : Node) =
        if node.Right = null then
            ()
        elif range.max >= node.Right.Range.max + 1L then
            node.Right <- node.Right.Right
            mergeRight(node)
        elif range.max >= node.Right.Range.min - 1L then
            range <- {min = range.min; max = node.Right.Range.max}
            node.Right <- node.Right.Right
        else
            cleanLeft(node.Right)

    let (|Before|After|BeforeOverlap|AfterOverlap|Superset|Subset|) r =
        if r.min > range.max + 1L then After
        elif r.min >= range.min then
            if r.max <= range.max then Subset
            else AfterOverlap
        elif r.max < range.min - 1L then Before
        elif r.max <= range.max then
            if r.min >= range.min then Subset
            else BeforeOverlap
        else Superset

    member this.Insert r =
        match r with
        | After ->
            if right = null then
                right <- Node(null, null, r)
            else
                right.Insert(r)
        | AfterOverlap ->
            range <- {min = range.min; max = r.max}
            mergeRight(this)
        | Before ->
            if left = null then
                left <- Node(null, null, r)
            else
                left.Insert(r)
        | BeforeOverlap ->
            range <- {min = r.min; max = range.max}
            mergeLeft(this)
        | Superset ->
            range <- r
            mergeLeft(this)
            mergeRight(this)
        | Subset -> ()

    member this.Left with get() : Node = left and set(x) = left <- x
    member this.Right with get() : Node = right and set(x) = right <- x
    member this.Range with get() : Range = range and set(x) = range <- x

    static member op_Equality (a : Node, b : Node) =
        a.Range = b.Range

    override this.ToString() =
        sprintf "%A" this.Range

type RangeTree() =
    let mutable root = null

    member this.Add(range) =
        if root = null then
            root <- Node(null, null, range)
        else
            root.Insert(range)

    static member fromArray(values : Range seq) =
        let tree = new RangeTree()
        values |> Seq.iter (fun value -> tree.Add(value))
        tree

    member this.Seq
        with get() =
            let rec inOrder(node : Node) =
                seq {
                    if node <> null then
                        yield! inOrder node.Left
                        yield {min = node.Range.min; max = node.Range.max}
                        yield! inOrder node.Right
                }
            inOrder root

let TestRange() =
    printf "\n"
    let source(n) =
        let rnd = new Random(n)
        let rand x = rnd.NextDouble() * float x |> int64
        let rangeRnd() =
            let a = rand 1500
            {min = a; max = a + rand 15}
        [|for n in 1 .. 50 do yield rangeRnd()|]
    let shuffle n (array:_[]) =
        let rnd = new Random(n)
        for i in 0 .. array.Length - 1 do
            let n = rnd.Next(i)
            let temp = array.[i]
            array.[i] <- array.[n]
            array.[n] <- temp
        array
    let testRangeAdd n i =
        let dataSet1 = source (n+0)
        let dataSet2 = source (n+1)
        let dataSet3 = source (n+2)
        let result1 = Array.concat [dataSet1; dataSet2; dataSet3] |> shuffle (i+3) |> RangeTree.fromArray
        let result2 = Array.concat [dataSet2; dataSet3; dataSet1] |> shuffle (i+4) |> RangeTree.fromArray
        let result3 = Array.concat [dataSet3; dataSet1; dataSet2] |> shuffle (i+5) |> RangeTree.fromArray
        let test1 = (result1.Seq, result2.Seq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let test2 = (result2.Seq, result3.Seq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let test3 = (result3.Seq, result1.Seq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let print dataSet =
            dataSet |> Seq.iter (fun r -> printf "%s " <| string r)
        if not (test1 && test2 && test3) then
            printf "\n\nTest# %A: " n
            printf "\nSource 1: %A: " (n+0)
            dataSet1 |> print
            printf "\nSource 2: %A: " (n+1)
            dataSet2 |> print
            printf "\nSource 3: %A: " (n+2)
            dataSet3 |> print
            printf "\nResult 1: %A: " (n+0)
            result1.Seq |> print
            printf "\nResult 2: %A: " (n+1)
            result2.Seq |> print
            printf "\nResult 3: %A: " (n+2)
            result3.Seq |> print
        ()
    for i in 1 .. 10 do
        for n in 1 .. 1000 do
            testRangeAdd n i
        printf "\n%d" (i * 1000)
    printf "\nDone"

TestRange()
System.Console.ReadLine() |> ignore
Test cases for Range
After          (11, 14)       |   | <-->
AfterOverlap   (10, 14)       |   |<--->
AfterOverlap   ( 9, 14)       |   +---->
AfterOverlap   ( 6, 14)       |<--+---->
"Test Case"    ( 5,  9)       +---+
BeforeOverlap  ( 0,  8)  <----+-->|
BeforeOverlap  ( 0,  5)  <----+   |
BeforeOverlap  ( 0,  4)  <--->|   |
Before         ( 0,  3)  <--> |   |
Superset       ( 4, 10)      <+---+>
Subset         ( 5,  9)       +---+
Subset         ( 6,  8)       |<->|
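The invariant behind these cases (no overlapping or adjacent ranges, and min <= max) can also be checked directly on the tree's in-order output. The helper below is a small sketch; the name isCanonical is made up here, and it assumes the sequence is already in ascending order, as an in-order traversal should be. It may be a useful addition to the test above because Seq.forall2 stops at the end of the shorter sequence, so two results of different lengths can still compare as equal.

```fsharp
type Range = { min : int64; max : int64 }

// True when every range is well formed and each range starts at least
// two past the end of the previous one (so no overlap and no adjacency).
let isCanonical (ranges: Range seq) : bool =
    let wellFormed = ranges |> Seq.forall (fun r -> r.min <= r.max)
    let separated =
        ranges
        |> Seq.pairwise
        |> Seq.forall (fun (a, b) -> a.max + 1L < b.min)
    wellFormed && separated
```

For example, (0, 5) followed by (7, 8) passes, while (0, 5) followed by (6, 8) fails because the two ranges are adjacent and should have been merged into (0, 8).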
This is not an answer.
I adapted my test case to run against Juliet's code. It fails on a number of cases; however, I do see it passing some tests.
type Range =
    { min : int64; max : int64 }
    with
        override this.ToString() =
            sprintf "(%d, %d)" this.min this.max

let rangeSeqToJTree ranges =
    ranges |> Seq.fold (fun tree range -> tree |> insert (range.min, range.max)) Nil

let JTreeToRangeSeq node =
    let rec inOrder node =
        seq {
            match node with
            | JNode(left, min, max, right) ->
                yield! inOrder left
                yield {min = min; max = max}
                yield! inOrder right
            | Nil -> ()
        }
    inOrder node

let TestJTree() =
    printf "\n"
    let source(n) =
        let rnd = new Random(n)
        let rand x = rnd.NextDouble() * float x |> int64
        let rangeRnd() =
            let a = rand 15
            {min = a; max = a + rand 5}
        [|for n in 1 .. 5 do yield rangeRnd()|]
    let shuffle n (array:_[]) =
        let rnd = new Random(n)
        for i in 0 .. array.Length - 1 do
            let n = rnd.Next(i)
            let temp = array.[i]
            array.[i] <- array.[n]
            array.[n] <- temp
        array
    let testRangeAdd n i =
        let dataSet1 = source (n+0)
        let dataSet2 = source (n+1)
        let dataSet3 = source (n+2)
        let result1 = Array.concat [dataSet1; dataSet2; dataSet3] |> shuffle (i+3) |> rangeSeqToJTree
        let result2 = Array.concat [dataSet2; dataSet3; dataSet1] |> shuffle (i+4) |> rangeSeqToJTree
        let result3 = Array.concat [dataSet3; dataSet1; dataSet2] |> shuffle (i+5) |> rangeSeqToJTree
        let test1 = (result1 |> JTreeToRangeSeq, result2 |> JTreeToRangeSeq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let test2 = (result2 |> JTreeToRangeSeq, result3 |> JTreeToRangeSeq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let test3 = (result3 |> JTreeToRangeSeq, result1 |> JTreeToRangeSeq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let print dataSet =
            dataSet |> Seq.iter (fun r -> printf "%s " <| string r)
        if not (test1 && test2 && test3) then
            printf "\n\nTest# %A: " n
            printf "\nSource 1: %A: " (n+0)
            dataSet1 |> print
            printf "\nSource 2: %A: " (n+1)
            dataSet2 |> print
            printf "\nSource 3: %A: " (n+2)
            dataSet3 |> print
            printf "\n\nResult 1: %A: " (n+0)
            result1 |> JTreeToRangeSeq |> print
            printf "\nResult 2: %A: " (n+1)
            result2 |> JTreeToRangeSeq |> print
            printf "\nResult 3: %A: " (n+2)
            result3 |> JTreeToRangeSeq |> print
        ()
    for i in 1 .. 1 do
        for n in 1 .. 10 do
            testRangeAdd n i
        printf "\n%d" (i * 10)
    printf "\nDone"

TestJTree()
Got it working! I think the hardest part was figuring out how to make recursive calls on children while passing state back up the stack.
Performance is rather interesting. When inserting mainly ranges that collide and get merged together, the mutable version is faster, while if you insert mainly non-overlapping ranges and fill out the tree, the immutable version is faster. I've seen performance swing by up to 100% either way.
Here's the complete code.
module StackOverflowQuestion

open System

type Range =
    { min : int64; max : int64 }
    with
        override this.ToString() =
            sprintf "(%d, %d)" this.min this.max

type RangeTree =
    | Node of RangeTree * int64 * int64 * RangeTree
    | Nil

// Clean right deals with merging when the node to merge with is not on the
// left outside of the tree. It travels right inside the tree looking for an
// overlapping node. If it finds one it merges the range and replaces the
// node with its left child thereby deleting it. If it finds a subset node
// it replaces it with its left child, checks it and continues looking right.
let rec cleanRight n node =
    match node with
    | Node(left, min, max, (Node(left', min', max', right') as right)) ->
        if n > max' + 1L then
            let node, n' = right |> cleanRight n
            Node(left, min, max, node), n'
        elif n >= min' then
            Node(left, min, max, left'), min'
        else
            Node(left, min, max, left') |> cleanRight n
    | _ -> node, n

// Symmetric to clean right
let rec cleanLeft x node =
    match node with
    | Node(Node(left', min', max', right') as left, min, max, right) ->
        if x < min' - 1L then
            let node, x' = left |> cleanLeft x
            Node(node, min, max, right), x'
        elif x <= max' then
            Node(right', min, max, right), max'
        else
            Node(right', min, max, right) |> cleanLeft x
    | Nil -> node, x
    | _ -> node, x

// Merge left is called whenever the min value of a node decreases.
// It handles the case of left node overlap/subsets and merging/deleting them.
// When no more overlaps are found on the left nodes it calls clean right.
let rec mergeLeft n node =
    match node with
    | Node(Node(left', min', max', right') as left, min, max, right) ->
        if n <= min' - 1L then
            Node(left', min, max, right) |> mergeLeft n
        elif n <= max' + 1L then
            Node(left', min', max, right)
        else
            let node, min' = left |> cleanRight n
            Node(node, min', max, right)
    | _ -> node

// Symmetric to merge left
let rec mergeRight x node =
    match node with
    | Node(left, min, max, (Node(left', min', max', right') as right)) ->
        if x >= max' + 1L then
            Node(left, min, max, right') |> mergeRight x
        elif x >= min' - 1L then
            Node(left, min, max', right')
        else
            let node, max' = right |> cleanLeft x
            Node(left, min, max', node)
    | node -> node

let (|Before|After|BeforeOverlap|AfterOverlap|Superset|Subset|) (min, max, min', max') =
    if min > max' + 1L then After
    elif min >= min' then
        if max <= max' then Subset
        else AfterOverlap
    elif max < min' - 1L then Before
    elif max <= max' then
        if min >= min' then Subset
        else BeforeOverlap
    else Superset

let rec insert min' max' this =
    match this with
    | Node(left, min, max, right) ->
        match (min', max', min, max) with
        | After -> Node(left, min, max, right |> insert min' max')
        | AfterOverlap -> Node(left, min, max', right) |> mergeRight max'
        | Before -> Node(left |> insert min' max', min, max, right)
        | BeforeOverlap -> Node(left, min', max, right) |> mergeLeft min'
        | Superset -> Node(left, min', max', right) |> mergeLeft min' |> mergeRight max'
        | Subset -> this
    | Nil -> Node(Nil, min', max', Nil)

let rangeSeqToRangeTree ranges =
    ranges |> Seq.fold (fun tree range -> tree |> insert range.min range.max) Nil

let rangeTreeToRangeSeq node =
    let rec inOrder node =
        seq {
            match node with
            | Node(left, min, max, right) ->
                yield! inOrder left
                yield {min = min; max = max}
                yield! inOrder right
            | Nil -> ()
        }
    inOrder node

let TestImmutableRangeTree() =
    printf "\n"
    let source(n) =
        let rnd = new Random(n)
        let rand x = rnd.NextDouble() * float x |> int64
        let rangeRnd() =
            let a = rand 15000
            {min = a; max = a + rand 150}
        [|for n in 1 .. 200 do yield rangeRnd()|]
    let shuffle n (array:_[]) =
        let rnd = new Random(n)
        for i in 0 .. array.Length - 1 do
            let n = rnd.Next(i)
            let temp = array.[i]
            array.[i] <- array.[n]
            array.[n] <- temp
        array
    let print dataSet =
        dataSet |> Seq.iter (fun r -> printf "%s " <| string r)
    let testRangeAdd n i =
        let dataSet1 = source (n+0)
        let dataSet2 = source (n+1)
        let dataSet3 = source (n+2)
        let result1 = Array.concat [dataSet1; dataSet2; dataSet3] |> shuffle (i+3) |> rangeSeqToRangeTree
        let result2 = Array.concat [dataSet2; dataSet3; dataSet1] |> shuffle (i+4) |> rangeSeqToRangeTree
        let result3 = Array.concat [dataSet3; dataSet1; dataSet2] |> shuffle (i+5) |> rangeSeqToRangeTree
        let test1 = (result1 |> rangeTreeToRangeSeq, result2 |> rangeTreeToRangeSeq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let test2 = (result2 |> rangeTreeToRangeSeq, result3 |> rangeTreeToRangeSeq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        let test3 = (result3 |> rangeTreeToRangeSeq, result1 |> rangeTreeToRangeSeq) ||> Seq.forall2 (fun a b -> a.min = b.min && a.max = b.max)
        if not (test1 && test2 && test3) then
            printf "\n\nTest# %A: " n
            printf "\nSource 1: %A: " (n+0)
            dataSet1 |> print
            printf "\nSource 2: %A: " (n+1)
            dataSet2 |> print
            printf "\nSource 3: %A: " (n+2)
            dataSet3 |> print
            printf "\n\nResult 1: %A: " (n+0)
            result1 |> rangeTreeToRangeSeq |> print
            printf "\nResult 2: %A: " (n+1)
            result2 |> rangeTreeToRangeSeq |> print
            printf "\nResult 3: %A: " (n+2)
            result3 |> rangeTreeToRangeSeq |> print
        ()
    for i in 1 .. 10 do
        for n in 1 .. 100 do
            testRangeAdd n i
        printf "\n%d" (i * 10)
    printf "\nDone"

TestImmutableRangeTree()
System.Console.ReadLine() |> ignore
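The "passing state back up the stack" idea used by cleanLeft and cleanRight above (each returns a rebuilt subtree together with a value) may be easier to see in miniature. The sketch below is not part of the range tree; the Tree type and popMin function are invented for this illustration on a plain binary search tree of ints.

```fsharp
type Tree =
    | Node of Tree * int * Tree
    | Nil

// Removes the smallest value from the tree. The recursive call hands back
// both the value it found and the rebuilt subtree, and each caller rebuilds
// its own node around that subtree on the way back up.
let rec popMin (tree: Tree) : (int * Tree) option =
    match tree with
    | Nil -> None
    | Node(Nil, v, right) -> Some(v, right)
    | Node(left, v, right) ->
        popMin left
        |> Option.map (fun (m, left') -> (m, Node(left', v, right)))
```

Returning a tuple this way replaces the in-place mutation of node.Left and range in the mutable version: nothing is overwritten, and only the nodes along the path to the change are rebuilt.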
It looks like you're defining a binary tree which is basically a union of a bunch of ranges. So, you have the following scenarios:
      (10, 20)          left             (10, 20)
      /      \          -->              /      \
  (0, 5)  (25, 30)     (7, 8)        (7, 8)  (25, 30)
                                      /
                                  (0, 5)

      (10, 20)          right            (10, 20)
      /      \          -->              /      \
  (0, 5)  (25, 30)    (21, 22)       (0, 5)  (21, 22)
                                                  \
                                                (25, 30)

      (10, 20)          subset           (10, 20)
      /      \          -->              /      \
  (0, 5)  (25, 30)    (15, 19)       (0, 5)  (25, 30)

      (10, 20)        R-superset         (10, 30)
      /      \          -->              /
  (0, 5)  (25, 30)    (11, 30)       (0, 5)

      (10, 20)        L-superset         (0, 20)
      /      \          -->                    \
  (0, 5)  (25, 30)    (0, 10)                (25, 30)

      (10, 20)        LR-superset        (0, 30)
      /      \          -->
  (0, 5)  (25, 30)    (0, 30)
The L-, R- and LR-superset cases are interesting because they require merging/deleting nodes when you insert a range that already contains other nodes.
The following is hastily written and not tested very well, but appears to satisfy the simple definition above:
type JTree =
    | JNode of JTree * int64 * int64 * JTree
    | Nil

let rec merge = function
    | JNode(JNode(ll, lmin, lmax, lr), min, max, r) when min <= lmin -> merge <| JNode(ll, min, max, r)
    | JNode(l, min, max, JNode(rl, rmin, rmax, rr)) when max >= rmax -> merge <| JNode(l, min, max, rr)
    | n -> n

let rec insert (min, max) = function
    | JNode(l, min', max', r) ->
        let node =
            // equal.
            // e.g. Given Node(l, 10, 20, r) insert (10, 20)
            if min' = min && max' = max then JNode(l, min', max', r)
            // before. Insert left
            // e.g. Given Node(l, 10, 20, r) insert (5, 7)
            elif min' >= max then JNode(insert (min, max) l, min', max', r)
            // after. Insert right
            // e.g. Given Node(l, 10, 20, r) insert (30, 40)
            elif max' <= min then JNode(l, min', max', insert (min, max) r)
            // superset
            // e.g. Given Node(l, 10, 20, r) insert (0, 40)
            elif min' >= min && max' <= max then JNode(l, min, max, r)
            // overlaps left
            // e.g. Given Node(l, 10, 20, r) insert (5, 15)
            elif min' >= min && max' >= max then JNode(l, min, max', r)
            // overlaps right
            // e.g. Given Node(l, 10, 20, r) insert (15, 40)
            elif min' <= min && max' <= max then JNode(l, min', max, r)
            // subset.
            // e.g. Given Node(l, 10, 20, r) insert (15, 17)
            elif min' <= min && max' >= max then JNode(l, min', max', r)
            // shouldn't happen
            else failwithf "insert (%i, %i) into Node(l, %i, %i, r)" min max min' max'
        // balances left and right sides
        merge node
    | Nil -> JNode(Nil, min, max, Nil)
JTree = Juliet Tree :) The merge function does all the heavy lifting. It'll merge as far as possible down the left spine, then as far as possible down the right spine.
I'm not wholly convinced that my overlaps left and overlaps right cases are implemented properly, but the other cases should be correct.
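One way to gain confidence in the overlap cases is to compare the tree's in-order output against a brute-force oracle. The sketch below sorts the input and merges in a single pass; the name mergeRanges is hypothetical, and it assumes the same rule as the question, namely that adjacent ranges such as (10, 20) and (21, 22) should be joined.

```fsharp
// Reference implementation: sort by lower bound, then fold, extending the
// most recent range whenever the next one overlaps it or touches it.
let mergeRanges (ranges: (int64 * int64) list) : (int64 * int64) list =
    ranges
    |> List.sortBy fst
    |> List.fold (fun acc (lo, hi) ->
        match acc with
        | (prevLo, prevHi) :: rest when lo <= prevHi + 1L ->
            (prevLo, max prevHi hi) :: rest   // overlap or adjacency: extend
        | _ -> (lo, hi) :: acc) []            // gap: start a new range
    |> List.rev
```

Feeding the same ranges to insert and to mergeRanges, then comparing the in-order traversal with the oracle's list, should expose any case the tree handles differently.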