简易拟合

lua-users home
wiki

这里介绍一种使用 LuaMatrix 拟合曲线的方法,例如直线、抛物线和指数函数。

从以下地址下载完整包:

http://luaforge.net/projects/luamatrix/

代码和方法非常简单。

首先将 x 值放入一个表中。然后将它与 y 值连接起来。使用高斯-约旦消元法来获取变量的结果。

对于指数函数,唯一需要考虑的是如何将其线性化。这很简单,例如

y = a * x^b | ln ==> ln(y) = ln( a ) + b * ln( x )

然后可以使用 fit.linear() 再次获取变量 a 和 b。

--///////////////////--
--// Curve Fitting //--
--///////////////////--

-- v 0.2

-- Lua 5.1 compatible

-- little add-on to the matrix module, to show some curve fitting

-- http://luaforge.net/projects/LuaMatrix
-- https://lua-users.lua.ac.cn/wiki/SimpleFit

-- Licensed under the same terms as Lua itself.

-- requires matrix module
local matrix = require "matrix"

-- The Fit Table
local fit = {}

-- Note all these Algos use the Gauss-Jordan Method to caculate equation systems

-- function to get the results
local function getresults( mtx )
   assert( #mtx+1 == #mtx[1], "Cannot calculate Results" )
   mtx:dogauss()
   -- tresults
   local cols = #mtx[1]
   local tres = {}
   for i = 1,#mtx do
      tres[i] = mtx[i][cols]
   end
   return unpack( tres )
end

-- fit.linear ( x_values, y_values )
-- fit a straight line
-- model (  y = a + b * x  )
-- returns a, b
function fit.linear( x_values,y_values )
   -- x_values = { x1,x2,x3,...,xn }
   -- y_values = { y1,y2,y3,...,yn }
   
   -- values for A matrix
   local a_vals = {}
   -- values for Y vector
   local y_vals = {}

   for i,v in ipairs( x_values ) do
      a_vals[i] = { 1, v }
      y_vals[i] = { y_values[i] }
   end

   -- create both Matrixes
   local A = matrix:new( a_vals )
   local Y = matrix:new( y_vals )

   local ATA = matrix.mul( matrix.transpose(A), A )
   local ATY = matrix.mul( matrix.transpose(A), Y )

   local ATAATY = matrix.concath(ATA,ATY)

   return getresults( ATAATY )
end

-- fit.parabola ( x_values, y_values )
-- Fit a parabola
-- model (  y = a + b * x + c * x� )
-- returns a, b, c
function fit.parabola( x_values,y_values )
   -- x_values = { x1,x2,x3,...,xn }
   -- y_values = { y1,y2,y3,...,yn }

   -- values for A matrix
   local a_vals = {}
   -- values for Y vector
   local y_vals = {}

   for i,v in ipairs( x_values ) do
      a_vals[i] = { 1, v, v*v }
      y_vals[i] = { y_values[i] }
   end

   -- create both Matrixes
   local A = matrix:new( a_vals )
   local Y = matrix:new( y_vals )

   local ATA = matrix.mul( matrix.transpose(A), A )
   local ATY = matrix.mul( matrix.transpose(A), Y )

   local ATAATY = matrix.concath(ATA,ATY)

   return getresults( ATAATY )
end

-- fit.exponential ( x_values, y_values )
-- Fit exponential
-- model (  y = a * x^b )
-- returns a, b
function fit.exponential( x_values,y_values )
   -- convert to linear problem
   -- ln(y) = ln(a) + b * ln(x)
   for i,v in ipairs( x_values ) do
      x_values[i] = math.log( v )
      y_values[i] = math.log( y_values[i] )
   end

   local a,b = fit.linear( x_values,y_values )

   return math.exp(a), b
end

return fit

--///////////////--
--// chillcode //--
--///////////////--

测试代码

-- require fit
-- local fit = require "fit"
local fit = dofile( "fit.lua" )

print( "Fit a straight line " )
-- x(i) = 2  | 3  | 4  | 5
-- y(i) = 5  | 9  | 15 | 21
-- model = y = a +  b * x
-- r(i) = y(i) - ( a + b * x(i) )
local a,b = fit.linear(	{ 2,3, 4, 5 },
			{ 5,9,15,21 } )
print( "=>    y = ( "..a.." )  +  ( "..b.." ) * x")

print( "Fit a parabola " )
local a, b, c = fit.parabola(	{ 0,1,2,4,6 },
				{ 3,1,0,1,4 } )
print( "=>    y = ( "..a.." )  +  ( "..b.." ) * x  +  ( "..c.." ) * x�")

print( "Fit exponential" )
local a, b = fit.exponential( {1,  2,  3,  4,   5},
			{1,3.1,5.6,9.1,12.9} )
print( "=>    y = ( "..a.." )  *  x^( "..b.." )")


最近更改 · 偏好设置
编辑 · 历史记录
最后编辑于 2007 年 8 月 26 日下午 4:40 GMT (差异)