Why do continuations avoid stack overflow? - F#

I've been trying to understand continuations / CPS, and from what I can gather it builds up a delayed computation; once we get to the end of the list we invoke the final computation.
What I don't understand is why CPS prevents stack overflow when it seems analogous to building up a nested function, as in the naive approach in Example 1. Sorry for the long post, but I've tried to show the idea (and possibly where it goes wrong) from basics:
So:
let list1 = [1;2;3]
Example 1: "Naive approach"
let rec sumList = function
    | [] -> 0
    | h::t -> h + sumList t
So when this runs, iteratively it results in:
1 + sumList [2;3]
1 + (2 + sumList [3])
1 + (2 + (3 + 0))
So the nesting (and overflow issues) can be overcome by tail recursion - passing an accumulator, i.e.
Example 2: "Tail Recursion"
let sumListACC lst =
    let rec loop l acc =
        match l with
        | [] -> acc
        | h::t -> loop t (h + acc)
    loop lst 0
i.e.
loop [2;3] (1+0)
loop [3] (2+1)
loop [] (3+3)
So because the accumulator is evaluated at each step, there is no nesting and we avoid bursting the stack. Clear!
Next comes CPS. I understand this is required when we already have an accumulator but the function is not tail recursive, e.g. with foldBack. Although not required in the above example, applying CPS to this problem gives:
Example 3: "CPS"
let sumListCPS lst =
    let rec loop l cont =
        match l with
        | [] -> cont 0
        | h::t -> loop t (fun x -> cont (h + x))
    loop lst (fun x -> x)
To my understanding, iteratively this could be written as:
loop [2;3] (fun x -> cont (1+x))
loop [3] (fun x -> cont (1+x) -> cont (2+x))
loop [] (fun x -> cont (1+x) -> cont (2+x) -> cont (3+x))
which then reduces sequentially from the right with the final x = 0 i.e:
cont (1+x) -> cont (2+x) -> cont (3+0)
cont (1+x) -> cont (2+x) -> 3
cont (1+x) -> cont (2+3)
...
cont (1+5) -> 6
which I suppose is analogous to:
cont(1+cont(2+cont(3+0)))
(1+(2+(3+0)))
Correction to original post - I realised that it is evaluated from the right; for example, replacing cont(h+x) with cont(h+2*x) yields 17 for the above example, consistent with (1+2*(2+2*(3+2*0))).
i.e. exactly where we started in Example 1. Based on this, since we still need to keep track of where we came from, why does using CPS prevent the overflow issue that Example 1 suffers from?
As far as I can tell it shouldn't, so where have I gone wrong?
I've read the following posts (multiple times) but the above confusion remains.
http://www.markhneedham.com/blog/2009/06/22/f-continuation-passing-style/
http://codebetter.com/matthewpodwysocki/2008/08/13/recursing-on-recursion-continuation-passing/
http://lorgonblog.wordpress.com/2008/04/05/catamorphisms-part-one/

What happens is quite simple.
.NET (and other platforms, but we're discussing F# right now) stores information in two locations: the stack (for value types, for pointers to objects, and for keeping track of function calls) and the heap (for objects).
In regular non-tail recursion, you keep track of your progress in the stack (quite obviously). In CPS, you keep track of your progress in lambda functions (which are on the heap!), and tail recursion optimization makes sure that the stack stays clear of any tracking.
As the heap is significantly larger than the stack, it is (in some cases) better to move the tracking from the stack to the heap - via CPS.
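To see the difference concretely, here is a quick sanity check (a sketch, not from the original post; it assumes tail calls are emitted, which is the default for Release builds and FSI and may not hold in Debug). The naive sumList from Example 1 typically overflows on a large list, while sumListCPS completes, because every call inside loop - and inside each continuation - is a tail call, so the pending additions accumulate as heap-allocated closures rather than stack frames.

// Sanity-check sketch, not part of the original question or answer:
let big = List.replicate 1000000 1   // a list large enough to break the naive version

// sumList big    // Example 1: usually throws StackOverflowException
sumListCPS big    // Example 3: returns 1000000; the stack depth stays constant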

Related

How to stop the traversal of a list

I am starting out in F# and I have this question.
Assume I have two lists a and b of the same length. I traverse these lists simultaneously and at each step test a condition on the elements of a and b, together with the result of a previous computation. If this test fails there is no need to keep going.
I wrote this code:
let mutable (i : int) = 0
let mutable (good : bool) = true
let mutable (previous : int) = 0
while good && i < len do
    good <- test a.[i] b.[i] previous
    previous <- my_func a.[i] b.[i]
    i <- i + 1
I saw this code, which is much, much better:
List.zip a b |> List.fold (fun (p, y) (a, b) -> (p && test a b y, my_func a b)) (true, 0)
But with my code, as soon as the test fails, the process finishes, which is not the case with the second version.
Is there a way, keeping the design of the second version, to stop the process early?
Thank you
I assume you are only interested in whether the final result is good.
As mentioned by Brian, you can use Seq.scan which behaves like Seq.fold but it returns all the intermediate states rather than just the final state. By using Seq instead of List you are also using a lazy sequence and so functions can terminate early. To do what you want, you can use Seq.scan together with Seq.forall, which will check that all values of a given sequence satisfy a certain condition - the nice thing here is that this can terminate early as soon as the condition is false.
Putting all this together, I get something like this:
Seq.zip a b
|> Seq.scan (fun (good, prev) (a, b) ->
    test a b prev, my_func a b) (true, 0)
|> Seq.forall (fun (good, _) -> good)
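For illustration (a sketch with placeholder definitions of test and my_func that are not from the question), the pipeline stops enumerating as soon as the first failing state appears:

let test a b prev = a + b > prev   // hypothetical condition
let my_func a b = a + b            // hypothetical computation
let a = [1; 2; 3; 4]
let b = [5; 0; 0; 0]

Seq.zip a b
|> Seq.scan (fun (good, prev) (a, b) -> test a b prev, my_func a b) (true, 0)
|> Seq.forall (fun (good, _) -> good)
// false - the test fails at the second pair, and the pairs after it are never processed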
Three ideas:
You could write your own variation of fold that stops when a flag is set to false. Personally, I think that would be cumbersome, though.
You could use Seq.scan to lazily accumulate all the results of my_func and then examine them, similar to this answer. Since Seq is lazy, the scan would short circuit the way you want. This is tricky to get right, though.
You could walk the lists recursively. IMHO, this is the simplest functional solution. Something like this:
let rec examine listA listB prev =
    match listA, listB with
    | headA :: tailA, headB :: tailB ->
        if test headA headB prev then
            let prev' = my_func headA headB
            examine tailA tailB prev'
        else false
    | [], [] -> true
    | _ -> false
examine listA listB 0
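A usage sketch (not in the original answer; test and my_func are hypothetical and would need to be in scope before examine is defined, e.g. let test a b prev = a + b > prev and let my_func a b = a + b):

examine [1; 2; 3] [5; 0; 0] 0   // false - the test fails at the second pair and the recursion stops there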

How to write efficient list/seq functions in F#? (mapFoldWhile)

I was trying to write a generic mapFoldWhile function, which is just mapFold but requires the state to be an option and stops as soon as it encounters a None state.
I don't want to use mapFold because it will transform the entire list, but I want it to stop as soon as an invalid state (i.e. None) is found.
This was my first attempt:
let mapFoldWhile (f : 'State option -> 'T -> 'Result * 'State option) (state : 'State option) (list : 'T list) =
    let rec mapRec f state list results =
        match list with
        | [] -> (List.rev results, state)
        | item :: tail ->
            let (result, newState) = f state item
            match newState with
            | Some x -> mapRec f newState tail (result :: results)
            | None -> ([], None)
    mapRec f state list []
The List.rev irked me, since the point of the exercise was to exit early and constructing a new list ought to be even slower.
So I looked up what F#'s very own map does, which was:
let map f list = Microsoft.FSharp.Primitives.Basics.List.map f list
The ominous Microsoft.FSharp.Primitives.Basics.List.map can be found here and looks like this:
let map f x =
    match x with
    | [] -> []
    | [h] -> [f h]
    | (h::t) ->
        let cons = freshConsNoTail (f h)
        mapToFreshConsTail cons f t
        cons
The consNoTail stuff is also in this file:
// optimized mutation-based implementation. This code is only valid in fslib, where mutation of private
// tail cons cells is permitted in carefully written library code.
let inline setFreshConsTail cons t = cons.(::).1 <- t
let inline freshConsNoTail h = h :: (# "ldnull" : 'T list #)
So I guess it turns out that F#'s immutable lists are actually mutable because performance? I'm a bit worried about this, having used the prepend-then-reverse list approach as I thought it was the "way to go" in F#.
I'm not very experienced with F# or functional programming in general, so maybe (probably) the whole idea of creating a new mapFoldWhile function is the wrong thing to do, but then what am I to do instead?
I often find myself in situations where I need to "exit early" because a collection item is "invalid" and I know that I don't have to look at the rest. I'm using List.pick or Seq.takeWhile in some cases, but in other instances I need to do more (mapFold).
Is there an efficient solution to this kind of problem (mapFoldWhile in particular and "exit early" in general) with functional programming concepts, or do I have to switch to an imperative solution / use a Collections.Generics.List?
In most cases, using List.rev is a perfectly sufficient solution.
You are right that the F# core library uses mutation and other dirty hacks to squeeze some more performance out of the F# list operations, but I think the micro-optimizations done there are not a particularly good example to follow. F# list functions are used almost everywhere, so there it might be a good trade-off, but I would not follow that approach in most situations.
Running your function with the following:
let l = [ 1 .. 1000000 ]
#time
mapFoldWhile (fun s v -> 0, s) (Some 1) l
I get ~240ms on the second line when I run the function without changes. When I just drop List.rev (so that it returns the data in the other order), I get around ~190ms. If you are really calling the function frequently enough that this matters, then you'd have to use mutation (actually, your own mutable list type), but I think that is rarely worth it.
For general "exit early" problems, you can often write the code as a composition of Seq.scan and Seq.takeWhile. For example, say you want to sum numbers from a sequence until you reach 1000. You can write:
input
|> Seq.scan (fun sum v -> v + sum) 0
|> Seq.takeWhile (fun sum -> sum < 1000)
Using Seq.scan generates a sequence of sums that is over the whole input, but since this is lazily generated, using Seq.takeWhile stops the computation as soon as the exit condition happens.
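As a quick check (a sketch, not from the original answer), the same pipeline works even on an infinite input, because Seq.takeWhile stops pulling from Seq.scan as soon as the condition fails:

let input = Seq.initInfinite (fun i -> i + 1)   // 1, 2, 3, ...
input
|> Seq.scan (fun sum v -> v + sum) 0
|> Seq.takeWhile (fun sum -> sum < 1000)
|> Seq.toList
// [0; 1; 3; 6; ...; 990] - the underlying sequence is only enumerated until the first running sum >= 1000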

How to combine state and continuation monads in F#

I'm trying to sum a tree using the Task Parallel Library, where child tasks are spawned only until the tree has been traversed to a certain depth; below that depth, it sums the remaining child nodes using continuation passing style, to avoid stack overflows.
However, the code looks pretty ugly - it would be nice to use a state monad to carry the current depth around, but the state monad isn't tail recursive. Alternatively, how would I modify the continuation monad to carry around the state? Or create a combination of the state and continuation monads?
let sumTreeParallelDepthCont tree cont =
    let rec sumRec tree depth cont =
        let newDepth = depth - 1
        match tree with
        | Leaf(num) -> cont num
        | Branch(left, right) ->
            if depth <= 0 then
                sumTreeContMonad left (fun leftM ->
                    sumTreeContMonad right (fun rightM ->
                        cont (leftM + rightM)))
            else
                let leftTask = Task.Factory.StartNew(fun () ->
                    let leftResult = ref 0
                    sumRec left newDepth (fun leftM ->
                        leftResult := leftM)
                    !leftResult
                )
                let rightTask = Task.Factory.StartNew(fun () ->
                    let rightResult = ref 0
                    sumRec right newDepth (fun rightM ->
                        rightResult := rightM)
                    !rightResult
                )
                cont (leftTask.Result + rightTask.Result)
    sumRec tree 4 cont // 4 levels deep
I've got a little more detail on this blog post: http://taumuon-jabuka.blogspot.co.uk/2012/06/more-playing-with-monads.html
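(The post does not show the Tree type or sumTreeContMonad; here is a minimal sketch of definitions consistent with the code above - the originals may differ - just to make the example self-contained:)

open System.Threading.Tasks   // for Task.Factory.StartNew used above

type Tree =
    | Leaf of int
    | Branch of Tree * Tree

// A plain CPS sum over the whole tree (hypothetical stand-in for sumTreeContMonad):
let rec sumTreeContMonad tree cont =
    match tree with
    | Leaf num -> cont num
    | Branch (left, right) ->
        sumTreeContMonad left (fun l ->
            sumTreeContMonad right (fun r -> cont (l + r)))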
I think it is important to first understand what your requirements are.
The sequential version of the algorithm does not need to keep the depth (because it always processes the rest of the tree). However, it needs to use continuations because the tree can be large.
The parallel version, on the other hand, needs to keep the depth (because you only want to make a limited number of recursive calls), but it does not need to use continuations (because the depth is quite limited and when you start a new task, it does not keep the stack anyway).
This means that you don't really need to combine the two aspects at all. Then you can rewrite the parallel version in a quite straightforward way:
let sumTreeParallelDepthCont tree =
    let rec sumRec tree depth =
        match tree with
        | Leaf(num) -> num
        | tree when depth <= 0 ->
            sumTreeContMonad tree id
        | Branch(left, right) ->
            let leftTask = Task.Factory.StartNew(fun () -> sumRec left (depth - 1))
            let rightResult = sumRec right (depth - 1)
            leftTask.Result + rightResult
    sumRec tree 4 // 4 levels deep
There is no need to duplicate the code from sumTreeContMonad because you can just call it on the current tree in the case tree when depth <= 0.
This also avoids using reference cells by creating Task<int> instead of Task and I modified the algorithm to only spawn one background task and do the second part of the work on the current thread.
In my eyes, the depth looks fine, the ugly bit is the ref cells and assignments. I am unclear why you need them; I think just passing id (identity function) as the cont parameter means that sumRec will return the value, and then you won't need the ref cells. (I may be wrong, this is analysis-at-a-glance.)
(I also would get rid of newDepth and just inline (depth-1) at the recursive call sites, as a matter of style.)
Finally, I've no idea what sumTreeContMonad is, but it appears that you could just use sumRec t -1 k instead of sumTreeContMonad t k and it would work the same.
(If your blog had code, rather than pictures of code, I might just post my own code with these refinements, but I don't feel like transcribing the data types and such. Why post pictures?)

How to efficiently find out if a sequence has at least n items?

Just naively using Seq.length may not be good enough, as it will blow up on infinite sequences.
Getting fancier with something like ss |> Seq.truncate n |> Seq.length will work, but behind the scenes it involves a double traversal of the argument sequence, chunk by chunk, via IEnumerator's MoveNext().
The best approach I was able to come up with so far is:
let hasAtLeast n (ss: seq<_>) =
    let mutable result = true
    use e = ss.GetEnumerator()
    for _ in 1 .. n do result <- e.MoveNext()
    result
This involves only a single traversal of the sequence (more accurately, performing e.MoveNext() n times) and correctly handles the boundary cases of empty and infinite sequences. I could further throw in a few small improvements like explicit handling of the specific cases of lists, arrays, and ICollections, or some cutting of the traversal length, but I wonder if a more effective approach to the problem exists that I may be missing?
Thank you for your help.
EDIT: Having on hand 5 overall implementation variants of the hasAtLeast function (2 of my own, 2 suggested by Daniel and one suggested by Ankur), I arranged a marathon between them. The results are a tie across all implementations, which proves Guvante right: the simplest composition of existing algorithms is best; there is no point in overengineering here.
Factoring in readability, I'd use either my own pure F#-based
let hasAtLeast n (ss: seq<_>) =
    Seq.length (Seq.truncate n ss) >= n
or the fully equivalent Linq-based one suggested by Ankur, which capitalizes on .NET integration:
let hasAtLeast n (ss: seq<_>) =
    ss.Take(n).Count() >= n
Here's a short, functional solution:
let hasAtLeast n items =
    items
    |> Seq.mapi (fun i x -> (i + 1), x)
    |> Seq.exists (fun (i, _) -> i = n)
Example:
let items = Seq.initInfinite id
items |> hasAtLeast 10000
And here's an optimally efficient one:
let hasAtLeast n (items:seq<_>) =
    use e = items.GetEnumerator()
    let rec loop n =
        if n = 0 then true
        elif e.MoveNext() then loop (n - 1)
        else false
    loop n
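A couple of quick checks (a sketch, not part of the original answer):

[1; 2; 3] |> hasAtLeast 5                   // false - MoveNext returns false before n reaches 0
Seq.initInfinite id |> hasAtLeast 1000000   // true - stops after exactly 1000000 MoveNext calls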
Functional programming breaks up work loads into small chunks that do very generic tasks that do one simple thing. Determining if there are at least n items in a sequence is not a simple task.
You already found both the solutions to this "problem", composition of existing algorithms, which works for the majority of cases, and creating your own algorithm to solve the issue.
However, I have to wonder whether your first solution wouldn't work fine. MoveNext() is certainly called only n times on the original enumerator, Current is never called, and even if MoveNext() is also called on some wrapper class, the performance implications are likely tiny unless n is huge.
EDIT:
I was curious, so I wrote a simple program to test the timing of the two methods. The truncate method was quicker both for a simple infinite sequence and for one with a Sleep(1) in it. It looks like I was right that the hand-rolled version was overengineering.
I think clarification is needed to explain what is happening in those methods. Seq.truncate takes a sequence and returns a sequence. Other than saving the value of n it doesn't do anything until enumeration. During enumeration it counts and stops after n values. Seq.length takes an enumeration and counts, returning the count when it ends. So the enumeration is only enumerated once, and the amount of overhead is a couple of method calls and two counters.
Using Linq this would be as simple as:
let hasAtLeast n (ss: seq<_>) =
    ss.Take(n).Count() >= n
F#'s Seq.take blows up if there are not enough elements, which is why Linq's Take is used here.
Example usage, showing that it traverses the seq only once and only as far as the required number of elements:
seq { for i = 0 to 5 do
          printfn "Generating %d" i
          yield i }
|> hasAtLeast 4 |> printfn "%A"

Return item at position x in a list

I was reading this post, While or Tail Recursion in F#, what to use when?, where several people say that the 'functional way' of doing things is to use maps/folds and higher-order functions instead of recursing and looping.
I have this function that returns the item at position x in a list:
let rec getPos l c = if c = 0 then List.head l else getPos (List.tail l) (c - 1)
how can it be converted to be more functional?
This is a primitive list function (also known as List.nth).
It is okay to use recursion, especially when creating the basic building blocks. Although it would be nicer with pattern matching instead of if-else, like this:
let rec getPos l c =
    match l with
    | h::_ when c = 0 -> h
    | _::t -> getPos t (c-1)
    | [] -> failwith "list too short"
It is possible to express this function with List.fold, however the result is less clear than the recursive version.
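For illustration only, here is one such fold-based version (a sketch under the hypothetical name getPosFold, not from the original answer; note that it also walks the whole list rather than stopping at the index):

let getPosFold l c =
    // carry (current index, item found so far) through the fold
    let _, found =
        l |> List.fold (fun (i, acc) x ->
            i + 1, (if i = c then Some x else acc)) (0, None)
    match found with
    | Some x -> x
    | None -> failwith "list too short"
// getPosFold [10; 20; 30] 1 evaluates to 20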
I'm not sure what you mean by more functional.
Are you rolling this yourself as a learning exercise?
If not, you could just try this:
> let mylist = [1;2;3;4];;
> let n = 2;;
> mylist.[n];;
Your definition is already pretty functional since it uses a tail-recursive function instead of an imperative loop construct. However, it also looks like something a Scheme programmer might have written because you're using head and tail.
I suspect you're really asking how to write it in a more idiomatic ML style. The answer is to use pattern matching:
let rec getPos list n =
    match list with
    | hd::tl ->
        if n = 0 then hd
        else getPos tl (n - 1)
    | [] -> failwith "Index out of range."
The recursion on the structure of the list is now revealed in the code. You also get a warning if the pattern matching is non-exhaustive so you're forced to deal with the index too big error.
You're right that functional programming also encourages the use of combinators like map or fold (so called points-free style). But too much of it just leads to unreadable code. I don't think it's warranted in this case.
Of course, Benjol is right, in practice you would just write mylist.[n].
If you'd like to use high-order functions for this, you could do:
let nth n = Seq.take (n+1) >> Seq.fold (fun _ x -> Some x) None
let nth n = Seq.take (n+1) >> Seq.reduce (fun _ x -> x)
But the idea is really to have basic constructions and combine them to build whatever you want. Getting the nth element of a sequence is clearly a basic building block that you should use. If you want the nth item, as Benjol mentioned, do myList.[n].
For building the basic constructions, there's nothing wrong with using recursion or mutable loops (and often, you have to do it this way).
Not as a practical solution, but as an exercise, here is one of the ways to express nth via foldr or, in F# terms, List.foldBack:
let myNth n xs =
    let step e f = function | 0 -> e | n -> f (n - 1)
    let error _ = failwith "List is too short"
    List.foldBack step xs error n
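A quick usage check (not in the original answer): indexing is zero-based, and running past the end hits the error continuation:

myNth 2 [10; 20; 30]   // 30
myNth 3 [10; 20; 30]   // throws "List is too short"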
