符号微分

lua-users home
wiki

这最初是为PenlightLibraries的一个小练习,但后来变得如此令人着迷,以至于需要更彻底的实现。

符号代数的第一步是定义一种表示方法。将表达式转换为合适的格式实际上很简单;不需要解析表达式,因为Lua可以为我们完成这项工作。使用pl.func库可以完成所有繁重的工作;它重新定义了算术运算,使其作用于*占位符表达式*(PE),这些表达式是包含称为占位符的虚拟变量的Lua表达式。pl.func定义了名为_1_2等的标准占位符参数,但Var函数将创建我们选择的新占位符。

utils.import 'pl.func'
a,b,c,d = Var 'a,b,c,d'
print(a+b+c+d)

这将以可读的形式打印出表达式。PE运算符表达式存储为类似于 {op='+',x,y} 的表的组合,这些表具有关联的元表,定义了诸如__add等元方法。作为一个树,具有Lua运算符的通常的结合律,我们得到

绘制这些图表很麻烦,因此更好的表示法是Lisp风格的S表达式。

1: (+ (+ (+ a b) c) d)

然而,通过我们将执行的各种操作,这种规范形式不是a+b+c+d的唯一可能表示。

2: (+ a (+ b (+ c d)))
3: (+ (+ a b) (+ c d))

现在,经验表明这会导致疯狂。相反,最好采用规范的Lisp表示法。

4: (+ a b c d)

一旦实现了这一点,许多操作就变得很简单,例如与(+ a c b d)进行比较,只需要对参数进行“无序比较”。以这种形式显示PE很简单。isPE通过查看元表来简单地检查表达式是否为占位符表达式。操作符为'X'的PE是占位符变量,其余的必须是表达式节点。

function sexpr (e)
	if isPE(e) then
  	  if e.op ~= 'X' then
	    local args = tablex.imap(sexpr,e)
	    return '('..e.op..' '..table.concat(args,' ')..')'
	  else
	    return e.repr
	  end
	else
	  return tostring(e)
	end
end

第一个任务是*平衡*表达式,它将表示1-3转换为4。

function balance (e)
	if isPE(e) and e.op ~= 'X' then
	  local op,args = e.op
	  if op == '+' or op == '*' then
		args = rcollect(e)
	  else
		args = imap(balance,e)
	  end
	  for i = 1,#args do
		e[i] = args[i]
	  end
	end
	return e
end

对于非交换运算符,思想就是通过将balance映射到PE的数组部分(即参数列表)来平衡所有子表达式。然后将它们复制回原处。棘手的部分是处理+和*,因为有必要收集所有来自表达式树(看起来像1、2或3)的参数,并将它们转换为第四种形式。

function tcollect (op,e,ls)
    if isPE(e) and e.op == op then
	for i = 1,#e do
	   tcollect(op,e[i],ls)
        end
    else
        ls:append(e)
        return
    end
end

function rcollect (e)
    local res = List()
    tcollect(e.op,e,res)
    return res
end

这会递归地遍历相同的运算符链(前面提到的(+ (+ ...)),并收集参数,将它们展平成n元+或*表达式。

这是一个有用的函数,它遵循相同的递归模式。

-- does this PE contain a reference to x?
function references (e,x)
    if isPE(e) then
		if e.op == 'X' then return x.repr == e.repr
		else
			return find_if(e,references,x)
		end
	else
		return false
	end
end

这里是创建n元乘积和和的函数。

function muli (args) return PE{op='*',unpack(args)} end
function addi (args) return PE{op='+',unpack(args)} end

有了这个,基本的微分规则就不难了。首先,只考虑包含该变量的子表达式。

function diff (e,x)
    if isPE(e) and references(e,x) then
		local op = e.op
        if op == 'X' then
            return 1
	else
   	    local a,b = e[1],e[2]
            if op == '+' then -- differentiation is linear
		local args = imap(diff,e,x)
		return balance(addi(args))
            elseif op == '*' then -- product rule
		local res,d,ee = {}
		for i = 1,#e do
			d = fold(diff(e[i],x))
			if d ~= 0 then
	 		  ee = {unpack(e)} -- make a copy
			  ee[i] = d
			  append(res,balance(muli(ee)))
			end
		end
		if #res > 1 then return addi(res)
		else return res[1] end
            elseif op == '^' and isnumber(b) then -- power rule
                return b*x^(b-1)
            end
        end
	else
		return 0
	end
end

表达式的和的导数是导数的和。同样,imap负责将函数递归地应用于子表达式。构造结果后,我们为了方便再次进行平衡。

乘积规则在这里以其一般形式给出,并显式检查那些结果为零的项——这是fold的工作,稍后将讨论它。

(uvw..)' = u'vw.. + uv'w... + uvw'... + ...

最后是简单的幂法则。注意结果是如何以简单明了的方式表示的,因为所有这些运算符都在PE上操作。

事实上,如果您使用形式1,即二元+和*,那么所有这些规则肯定会更清晰!但然后简化将变得无法忍受。而简化(“折叠”)是棘手的,需要正确处理。fold是一个较长的函数,所以我会分段处理它。

local op = e.op
local addmul = op == '*' or op == '+'
-- first fold all arguments
local args = imap(fold,e)
if not addmul and not find_if(args,isPE) then
  -- no placeholders in these args, we can fold the expression.
  local opfn = optable[op]
  if opfn then
    return opfn(unpack(args))
  else
   return '?'
  end
elseif addmul then

第一个if正在寻找一个子表达式没有符号的情况,即它像2*510^2;在这种情况下,常数可以完全折叠。optable(在pl.operator中定义)提供了运算符名称与实现它们的实际函数之间的映射。

elseif op == '^' then
  if args[2] == 1 then return args[1] end -- identity
  if args[2] == 0 then return 1 end
end
return PE{op=op,unpack(args)}

这个子句正在清理像x^1y^0这样的表达式,这些表达式自然地来自diff中的幂法则。一旦args被处理完毕,就可以重新组合表达式。

这个例程的大部分内容处理棘手的配对,+和*。

-- split the args into two classes, PE args and non-PE args.
local classes = List.partition(args,isPE)
local pe,npe = classes[true],classes[false]

List.partition接受一个列表和一个函数,该函数接受一个参数并返回一个单一值。结果是一个表,其中键是返回的值,值是函数返回该值的元素的列表。所以

List{1,2,3,4}:partition(function(x) return x > 2 end)
--> {false={1,2},true={3,4}}
List{'one',math.sin,10,20,{1,2}}:partition(type)
--> {function={function: 00369110},string={one},number={10,20},table={{{1,2}} }

(在数学上,这些被称为*等价类*,而partition被称为*商集*)

在这种情况下,我们想将非符号参数与符号参数分开;顺序无关紧要。非符号参数npe可以折叠成一个常数。此时,运算符身份规则可以生效,这样我们就可以删除(* 0 x)并将(+ 0 x)简化为x

最后的简化是替换重复的值,这样(* x x)应该变成(^ x 2),而(+ x x x)应该变成(* x 3)pl.tablex中的count_map可以完成这项工作。它接受一个类似列表的表和一个定义等价性的函数,并返回一个从值到其出现次数的映射,因此 count_map{'a','b','a'} {a=2,b=1}

给定这个测试函数

function testdiff (e)
  balance(e)
  e = diff(e,x)
  balance(e)
  print('+ ',e)
  e = fold(e)
  print('- ',e)
end

以及这些情况

testdiff(x^2+1)
testdiff(3*x^2)
testdiff(x^2 + 2*x^3)
testdiff(x^2 + 2*a*x^3 + x^4)
testdiff(2*a*x^3)
testdiff(x*x*x)

我们得到这些结果,这表明为什么像fold这样的东西对于处理diff的结果如此必要。

+ 	2 * x ^ 1 + 0
- 	2 * x
+ 	3 * 2 * x
- 	6 * x
+ 	2 * x ^ 1 + 2 * 3 * x ^ 2
- 	2 * x + 6 * x ^ 2
+ 	2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3
- 	6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x
+ 	2 * a * 3 * x ^ 2
- 	6 * a * x ^ 2
+ 	1 * x * x + x * 1 * x + x * x * 1
- 	x ^ 2 * 3

https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua

https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua

SteveDonovan

另请参阅


RecentChanges · preferences
编辑 · 历史
最后编辑于 2012年7月4日 上午4:41 GMT (差异)