How to implement combinations tail-recursively? - lua

I'm teaching myself Lua by reading Ierusalimschy's Programming in Lua (4th edition), and doing the exercises. Exercise 6.5 is
Write a function that takes an array and prints all combinations of the elements in the array.
After this succinct statement the book gives a hint that makes it clear that what one is expected to do is to write a function that prints all the C(n, m) combinations of m elements from an array of n elements.
I implemented the combinations function shown below:
function combinations (array, m)
local append = function (array, item)
local copy = {table.unpack(array)}
copy[#copy + 1] = item
return copy
end
local _combinations
_combinations = function (array, m, prefix)
local n = #array
if n < m then
return
elseif m == 0 then
print(table.unpack(prefix))
return
else
local deleted = {table.unpack(array, 2, #array)}
_combinations(deleted, m - 1, append(prefix, array[1]))
_combinations(deleted, m, prefix)
end
end
_combinations(array, m, {})
end
It works OK, but it is not tail-recursive.
Can someone show me a tail-recursive function that does the same thing as combinations above does?
(For what it's worth, I am using Lua 5.3.)
NB: I realize that the exercise does not require that the function be tail-recursive. This is a requirement I have added myself, out of curiosity.
EDIT: I simplified the function slightly, by removing a couple of nested functions that were not adding much.

There is a third option, one that doesn't have a snake eating its tail. Although recursion with tail calls doesn't lead to stack overflow, I avoid it out of personal preference. I use a while loop and a stack that holds the information for each iteration. Within the loop you pop the next task from the stack, do the work, then push the next tasks onto the stack. I feel it looks cleaner and makes the nesting easier to visualize.
Here is how I would translate your code into the way I would write it:
function combinations(sequence, item)
local function append(array, item)
local copy = {table.unpack(array)}
copy[#copy + 1] = item
return copy
end
local stack = {}
local node = { sequence, item, {} }
while true do
local seq = node[ 1 ]
local itm = node[ 2 ]
local pre = node[ 3 ]
local n = #seq
if itm == 0 then
print(table.unpack(pre))
elseif n < itm then
-- do nothing
else
local reserve = {table.unpack(seq, 2, #seq)}
table.insert(stack, { reserve, itm, pre })
table.insert(stack, { reserve, itm-1, append(pre, seq[ 1 ]) })
end
if #stack > 0 then
node = stack[ #stack ] -- LIFO
stack[ #stack ] = nil
else
break
end
end
end
You can use this while-loop stack/node technique for just about any recursive method. Here is an example where it's applied to printing deeply nested tables: https://stackoverflow.com/a/42062321/5113346
My version, using your input example gives the same output:
1 2 3
1 2 4
1 2 5
1 3 4
1 3 5
1 4 5
2 3 4
2 3 5
2 4 5
3 4 5
Forgive me if it doesn't work with other arguments: I didn't try to solve the exercise, only to rewrite the code from your original post.

OK, I think I found one way to do this:
function combinations (array, m)
local dropfirst = function (array)
return {table.unpack(array, 2, #array)}
end
local append = function (array, item)
local copy = {table.unpack(array)}
copy[#copy + 1] = item
return copy
end
local _combinations
_combinations = function (sequence, m, prefix, queue)
local n = #sequence
local newqueue
if n >= m then
if m == 0 then
print(table.unpack(prefix))
else
local deleted = dropfirst(sequence)
if n > m then
newqueue = append(queue, {deleted, m, prefix})
else
newqueue = queue
end
return _combinations(deleted, m - 1,
append(prefix, sequence[1]),
newqueue)
end
end
if #queue > 0 then
newqueue = dropfirst(queue)
local newargs = append(queue[1], newqueue)
return _combinations(table.unpack(newargs))
end
end
_combinations(array, m, {}, {})
end
This version is, I think, tail-recursive. Unfortunately, it does not print out the results in as nice an order as did my original non-tail-recursive version (not to mention the added complexity of the code), but one can't have everything!
EDIT: Well, no, one can have everything! The version below is tail-recursive, and prints its results in the same order as does the original non-tail-recursive version:
function combinations (sequence, m, prefix, stack)
prefix, stack = prefix or {}, stack or {}
local n = #sequence
if n < m then return end
local newargs, newstack
if m == 0 then
print(table.unpack(prefix))
if #stack == 0 then return end
newstack = droplast(stack)
newargs = append(stack[#stack], newstack)
else
local deleted = dropfirst(sequence)
if n > m then
newstack = append(stack, {deleted, m, prefix})
else
newstack = stack
end
local newprefix = append(prefix, sequence[1])
newargs = {deleted, m - 1, newprefix, newstack}
end
return combinations(table.unpack(newargs)) -- tail call
end
It uses the following auxiliary functions:
function append (sequence, item)
local copy = {table.unpack(sequence)}
copy[#copy + 1] = item
return copy
end
function dropfirst (sequence)
return {table.unpack(sequence, 2, #sequence)}
end
function droplast (sequence)
return {table.unpack(sequence, 1, #sequence - 1)}
end
Example:
> combinations({1, 2, 3, 4, 5}, 3)
1 2 3
1 2 4
1 2 5
1 3 4
1 3 5
1 4 5
2 3 4
2 3 5
2 4 5
3 4 5
Ironically, this version achieves tail-recursion by implementing its own stack, so I am not sure it is ultimately any better than the non-tail-recursive version... Then again, I guess the function's stack actually lives in the heap (right?), because Lua's tables are passed around by reference (right?), so maybe this is an improvement, after all. (Please correct me if I'm wrong!)
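On the two "(right?)" questions: Lua does perform proper tail calls, and tables are reference values, so passing a table to a function does not copy it. A minimal sketch of both points, using a hypothetical countdown function that is not part of the exercise:

```lua
-- Hypothetical helper for illustration only.
-- "return f(...)" with nothing left to do afterwards is a proper tail call,
-- so the recursion depth does not grow the call stack.
function countdown(n, visited)
  if n == 0 then
    return visited
  end
  -- The table is shared with the caller; this mutation is visible outside.
  visited.count = visited.count + 1
  return countdown(n - 1, visited) -- tail call
end

local state = { count = 0 }
countdown(1000000, state)
print(state.count) --> 1000000
```

By contrast, something like return 1 + f(n - 1) is not a tail call, because the addition still has to run after f returns.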

Related

Reliable way of getting the exact decimals from any number

I'm having a problem returning a specific number of decimals from this function. I would like it to take that number from the "dec" argument, but I'm stuck with this right now.
Edit: I made it work with the edited version below, but isn't there a better way?
local function remove_decimal(t, dec)
if type(dec) == "number" then
for key, num in pairs(type(t) == "table" and t or {}) do
if type(num) == "number" then
local num_to_string = tostring(num)
local mod, d = math.modf(num)
-- find only decimal numbers
local num_dec = num_to_string:sub(#tostring(mod) + (mod == 0 and num < 0 and 3 or 2))
if dec <= #num_dec then
-- keep only the first dec decimals of num
local r = d < 0 and "-0." or "0."
local r2 = r .. num_dec:sub(1, dec)
t[key] = mod + tonumber(r2)
end
end
end
end
return t
end
By calling the function as shown below, I want a result like this:
result[1] > 0.12
result[2] > -0.12
result[3] > 123.45
result[4] > -1.23
local result = remove_decimal({0.123, -0.123, 123.456, -1.234}, 2)
print(result[1])
print(result[2])
print(result[3])
print(result[4])
I tried this, but it seems to work only with numbers that have a single-digit integer part: if the number is 12.34 instead of 1.34, for example, a decimal place is lost and it becomes 12.3. Using other methods:
local d = dec + (num < 0 and 2 or 1)
local r = tonumber(num_to_string:sub(1, -#num_to_string - d)) or 0
A good approach is to find the position of the decimal point (the dot, .) and then extract a substring starting from the first character to the dot's position plus how many digits you want:
local function truncate(number, dec)
local strnum = tostring(number)
local i, j = string.find(strnum, '%.')
if not i then
return number
end
local strtrn = string.sub(strnum, 1, i+dec)
return tonumber(strtrn)
end
Call it like this:
print(truncate(123.456, 2))
print(truncate(1234567, 2))
123.45
1234567
To bulk-truncate a set of numbers:
local function truncate_all(t, dec)
for key, value in pairs(t) do
t[key] = truncate(t[key], dec)
end
return t
end
Usage:
local result = truncate_all({0.123, -0.123, 123.456, -1.234}, 2)
for key, value in pairs(result) do
print(key, value)
end
1 0.12
2 -0.12
3 123.45
4 -1.23
One could use the function string.format, which is similar to the printf functions from C. With the format "%.2f" the resulting string contains 2 decimals, with "%.3f" it contains 3 decimals, and so on. The idea is to dynamically create the format "%.XXXf" corresponding to the number of decimals needed by the function, then call string.format with the newly created format string to generate the string "123.XXX". The last step is to convert the string back to a number with the function tonumber.
Note that if one wants the special character % to be preserved when string.format is called, one needs to write %%.
function KeepDecimals (Number, DecimalCount)
local FloatFormat = string.format("%%.%df", DecimalCount)
local String = string.format(FloatFormat, Number)
return tonumber(String)
end
The behavior seems close to what the OP is looking for:
for Count = 1, 5 do
print(KeepDecimals(1.123456789, Count))
end
This code should print the following:
1.1
1.12
1.123
1.1235
1.12346
Regarding the initial code, it's quite straight-forward to integrate the provided solution. Note that I renamed the function to keep_decimal because in my understanding, the function will keep the requested number of decimals, and discard the rest.
function keep_decimal (Table, Count)
local NewTable = {}
local NewIndex = 1
for Index = 1, #Table do
NewTable[NewIndex] = KeepDecimals(Table[Index], Count)
NewIndex = NewIndex + 1
end
return NewTable
end
The code can be tested easily by copying and pasting it into a Lua interpreter.
Result = keep_decimal({0.123, -0.123, 123.456, -1.234}, 2)
for Index = 1, #Result do
print(Result[Index])
end
This should print the following:
0.12
-0.12
123.46
-1.23
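Edit, following the clarification that truncation is needed:
function Truncate (Number, Digits)
-- scale by a power of ten (10 ^ Digits), not by Digits * 10
local Divider = 10 ^ Digits
local Scaled = Number * Divider
-- math.floor rounds toward negative infinity, so negative numbers need math.ceil
local Whole = Number < 0 and math.ceil(Scaled) or math.floor(Scaled)
return Whole / Divider
end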
On my computer, the code is working as expected:
> Truncate(123.456, 2)
123.45
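The difference between rounding and truncating shows up on the question's own inputs: string.format rounds 123.456 to 123.46, while the asker wants 123.45. A minimal sketch comparing the two, where round_to and truncate_to are hypothetical names chosen for illustration and math.modf is used to drop the fractional part toward zero:

```lua
-- Hypothetical helpers for illustration; the names are not from the thread.

-- Rounds to the given number of decimals via string.format.
function round_to(number, digits)
  local fmt = string.format("%%.%df", digits)
  return tonumber(string.format(fmt, number))
end

-- Truncates toward zero: math.modf returns the integral part first.
function truncate_to(number, digits)
  local scale = 10 ^ digits
  local whole = math.modf(number * scale)
  return whole / scale
end

for _, n in ipairs({ 0.123, -0.123, 123.456, -1.234 }) do
  print(n, round_to(n, 2), truncate_to(n, 2))
end
```

Because numbers are binary floating point, a value that is not exactly representable can end up one unit lower in the last kept digit after scaling, so the arithmetic approach is only a sketch and not a guarantee for every input.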

Stack Overflow on Lua metatable

I used to have a construct that worked with luajit:
mytbl = setmetatable({1}, {__index = function(tbl,idx) return tbl[idx - 1] + 1 end})
Now with plain Lua 5.4 this gives me a stack overflow:
> mytbl[1000]
stdin:1: C stack overflow
stack traceback:
stdin:1: in metamethod 'index'
....
The goal is to have a table where the default is to return the index itself:
mytbl[10]
should return 10. But when I say
mytbl[3] = 5
the value of
mytbl[10]
should be 12 (the values from 1 now yield 1,2,5,6,7,8,9,10,11,12,...)
Is there a way to get this in Lua 5.4 without the stack overflow? Or should I create another function for it?
You are accessing the table within the __index. That causes another __index to be called and so on.
Use rawget when you want to access the tbl itself.
If you want custom logic that does not rely on the existence of elements, write your function in a way that allows for tail recursion, or write it iteratively without any recursion at all:
__index=function(tbl, idx)
local acc = 0
for i=idx-1, 1, -1 do
local th = rawget(tbl, i)
if th then
return acc + th + 1
else
acc = acc + 1
end
end
return acc
end
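For completeness, here is a sketch of how the iterative __index above could be wired into the table from the question with setmetatable. The variable name stored is my own; everything else follows the snippet above:

```lua
-- The iterative __index attached to the question's table.
mytbl = setmetatable({ 1 }, {
  __index = function(tbl, idx)
    local acc = 0
    -- Walk down from idx - 1 until a stored value is found.
    for i = idx - 1, 1, -1 do
      local stored = rawget(tbl, i) -- rawget bypasses __index, so no recursion
      if stored then
        return acc + stored + 1
      end
      acc = acc + 1
    end
    return acc
  end,
})

print(mytbl[10])   --> 10
mytbl[3] = 5
print(mytbl[10])   --> 12
print(mytbl[1000]) --> 1002
```

Each lookup of a missing key costs a loop proportional to the distance to the nearest stored value, but it uses no recursion, so there is no C stack overflow.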
This is what I came up with now:
__index=function(tbl, idx)
local max = 0
for k, v in next, tbl do
if k <= idx then max = v - k end
end
return idx + max
end
I only have very few entries in tbl, so this should be reasonably fast for my purposes.
My test:
mytbl[5] = 9
for i = 1, 10 do
print(mytbl[i])
end
outputs
1
2
3
4
9
10
11
12
13
14
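One caveat about the loop in this follow-up: next does not guarantee any particular traversal order, so with several stored entries the last key visited is not necessarily the largest key below idx. A variant that tracks the largest qualifying key explicitly might look like the sketch below (offsetIndex and bestKey are hypothetical names, and all keys are assumed to be numbers):

```lua
-- Hypothetical variant: pick the largest stored key <= idx, whatever
-- order next happens to visit the entries in.
local function offsetIndex(tbl, idx)
  local bestKey, offset = nil, 0
  for k, v in next, tbl do
    if k <= idx and (bestKey == nil or k > bestKey) then
      bestKey, offset = k, v - k
    end
  end
  return idx + offset
end

mytbl = setmetatable({ 1 }, { __index = offsetIndex })
mytbl[5] = 9
for i = 1, 10 do
  print(mytbl[i])
end
```

With the single extra entry from the test above, this prints the same ten values as the original loop.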

lua: piggybacking multiple variables across 1 if statement

I have around 7+ variables: a=1, b=10, c=12...etc
I need to write an if statement for each that does this:
if var>0 then var-=1 end
If I need each of the variables to record their values after each iteration, is there a way for me to avoid writing out one if statement per variable?
I tried defining them all in a table like:
a=1;b=2;c=3
local t = {a,b,c}
for _,v in pairs(t) do
if v>0 then v-=1 end
end
a,b,c=t[1],t[2],t[3]
This code failed though, not sure why. Ultimately I'm looking for a more efficient way than simply writing the ifs. You can define efficient in terms of cost or tokens or both. The values used would be random, no pattern. The variable names can potentially be changed, i.e. a_1, a_2, a_3; it's not ideal though.
There are a couple of solutions. To shorten your code, you could write a function that processes the value and run it on each variable:
local function toward0(var)
if var > 0 then
return var - 1
end
return var
end
a = toward0(a)
b = toward0(b)
c = toward0(c)
You could also store the data in a table instead of in variables. Then you can process them in a loop:
local valuesThatNeedToBeDecremented = {a = 1, b = 10, c = 12}
for k, v in pairs(valuesThatNeedToBeDecremented) do
if v > 0 then
valuesThatNeedToBeDecremented[k] = v - 1
end
end
You forgot to reassign the new values to the table!
local a, b, c = 1, 2, 3
local t = {a, b, c}
for k, v in ipairs(t) do
if v > 0 then v -= 1 end
t[k] = v
end
a, b, c = t[1], t[2], t[3]
print(a, b, c)
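The reason the original attempt changed nothing is that numbers are copied by value: a variable that receives a number from a table holds its own copy, so changing it leaves the table untouched. A small sketch in standard Lua, which has no -= operator (the one in the question suggests a dialect that adds it); decrementPositive is a hypothetical name:

```lua
-- Reading a number out of a table gives a copy.
local t = { 1, 0, 12 }
local v = t[1]
v = v - 1
print(v, t[1]) -- v changed, t[1] did not

-- Hypothetical helper: write the new value back through the index.
function decrementPositive(values)
  for i, value in ipairs(values) do
    if value > 0 then
      values[i] = value - 1
    end
  end
  return values
end

local unpack = table.unpack or unpack -- Lua 5.1 vs 5.2+
local a, b, c = unpack(decrementPositive(t))
print(a, b, c) --> 0  0  11
```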

How to iterate through results in sets?

In my Tabletop Simulator mod I have a bag, when something is dropped in the bag the emptyContents() function is called. For example I can drop 15 dice in the bag.
In the emptyContents() function I iterate over the objects in the bag. But as you can see I have to put in multiple if statements to catch the amount of dice put in because I want the dice to be spawned on different positions.
The contents variable is the amount of dice in the bag.
function emptyContents()
contents = self.getObjects()
for i, _ in ipairs(self.getObjects()) do
if i <= 6 then
self.takeObject(setPosition(5, -3))
elseif i <= 12 then
self.takeObject(setPosition(12.4,-5))
elseif i <= 18 then
self.takeObject(setPosition(19.8,-7))
end
end
end
How can I make the function less static? Because now I need to write if statements for each set of 6 dice.
Maybe you can add a config like this:
local t = {
{6, 5, -3},
{12, 12.4, -5},
{18, 19.8, -7},
}
function emptyContents()
contents = self.getObjects()
for i, _ in ipairs(self.getObjects()) do
for _, v in ipairs(t) do
local l, p1, p2 = unpack(v)
if i <= l then
self.takeObject(setPosition(p1, p2))
break
end
end
end
end
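If the positions follow a regular pattern, the lookup table can be replaced by arithmetic on the index. The sketch below assumes the spacing seen in the question continues (the first coordinate grows by 7.4 and the second drops by 2 for each further set of six dice); positionForIndex and DICE_PER_SET are hypothetical names, and the result would be passed to the question's own setPosition helper:

```lua
-- Hypothetical helper: derive the coordinates from the dice index.
local DICE_PER_SET = 6

function positionForIndex(i)
  -- 0 for dice 1..6, 1 for dice 7..12, 2 for dice 13..18, and so on
  local set = math.floor((i - 1) / DICE_PER_SET)
  return 5 + 7.4 * set, -3 - 2 * set
end

for _, i in ipairs({ 1, 6, 7, 13, 19 }) do
  print(i, positionForIndex(i))
end
```

Inside emptyContents() the body of the loop would then shrink to a single call such as self.takeObject(setPosition(positionForIndex(i))), with no upper limit on the number of sets.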

Explain why unpack() returns different results in Lua

The following script finds prime numbers in a range from 1 to 13.
When I explicitly iterate over the table that contains the results I can see that the script works as expected. However, if I use unpack() function on the table only the first 3 numbers get printed out.
From docs: unpack is "a special function with multiple returns. It receives an array and returns as results all elements from the array, starting from index 1".
Why is it not working in the script below?
t = {}
for i=1, 13 do t[i] = i end
primes = {}
for idx, n in ipairs(t) do
local isprime = true
for i=2, n-1 do
if n%i == 0 then
isprime = false
break
end
end
if isprime then
primes[idx] = n
end
end
print('loop printing:')
for i in pairs(primes) do
print(i)
end
print('unpack:')
print(unpack(primes))
Running
$ lua5.3 primes.lua
loop printing:
1
2
3
5
7
13
11
unpack:
1 2 3
Change
primes[idx] = n
to
primes[#primes+1] = n
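The reason is that idx is not sequential, because not every number is prime. The primes table ends up with holes (nothing is stored at indices 4, 6, 8, and so on), so it is not a proper sequence, and unpack stops at a border of the table, here right after index 3.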
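A self-contained sketch of the corrected script, with the prime search wrapped in a hypothetical function primesUpTo. It keeps the question's primality test as is (which also counts 1 as prime) and falls back to the global unpack on Lua 5.1:

```lua
local unpack = table.unpack or unpack -- table.unpack in Lua 5.2+, unpack in 5.1

-- Hypothetical wrapper around the question's loop.
function primesUpTo(limit)
  local primes = {}
  for n = 1, limit do
    local isprime = true
    for i = 2, n - 1 do
      if n % i == 0 then
        isprime = false
        break
      end
    end
    if isprime then
      primes[#primes + 1] = n -- append, so the table has no holes
    end
  end
  return primes
end

print(unpack(primesUpTo(13))) -- prints 1 2 3 5 7 11 13, tab-separated
```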
