树回归

树回归 优点:可以对复杂和非线性的数据建模 缺点:结果不易理解 适用数据类型: 数值型和标称型数据。 CART算法实现 bi

树回归

优点:可以对复杂和非线性的数据建模

缺点:结果不易理解

适用数据类型: 数值型和标称型数据。

CART算法实现

binSplitDataSet()函数,有三个参数:数据集合,待切分的特征和该特征的某个值。在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。

def binSplitDataSet(dataSet, feature, value):

mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]

mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]

return mat0, mat1

createTree()函数,有4个参数,数据集和其他三个可选参数,这些可选参数决定了树的类型: leafType给出了建立叶节点的函数,errorType代表误差计算函数,ops是一个包含树构建所需其它参数的元祖。

def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):

feat, val = chooseBestSplit(dataSet, leafType, errType, ops)

if feat == None: return val

retTree = {}

retTree['spInd'] = feat

retTree['spVal'] = val

lSet, rSet = binSplitDataSet(dataSet, feat, val)

retTree['left'] = createTree(lSet, leafType, errType, ops)

retTree['right'] = createTree(rSet, leafType, errType, ops)

return retTree

chooseBestSplit()函数,给定某个误差计算方法,该函数会找到数据集上的最佳二元切分方式。该函数需要完成两件事:用最佳方式切分数据集和生成相应的叶节点。

伪代码如下:

对每个特征:

对每个特征值:

将数据集切分成两份

计算切分的误差

如果当前误差小于当前最小误差,将当前切分设定为最佳切分并更新最小误差

返回最佳切分的特征和阀值

coding:

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):

tolS = ops[0]; tolN = ops[1]

#if all the target variables are the same value: quit and return value

if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1

return None, leafType(dataSet)

m,n = shape(dataSet)

#the choice of the best feature is driven by Reduction in RSS error from mean

S = errType(dataSet)

bestS = inf; bestIndex = 0; bestValue = 0

for featIndex in range(n-1):

for splitVal in set(dataSet[:,featIndex]):

mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)

if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue

newS = errType(mat0) + errType(mat1)

if newS < bestS:

bestIndex = featIndex

bestValue = splitVal

bestS = newS

#if the decrease (S-bestS) is less than a threshold don't do the split

if (S - bestS) < tolS:

return None, leafType(dataSet) #exit cond 2

mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)

if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3

return None, leafType(dataSet)

return bestIndex,bestValue#returns the best feature to split on

#and the value used for that split

函数chooseBestSplit()一开始为ops设定了tolS和tolN这两个值,它们是用户指定的参数,用于控制函数的停止时机。其中tolS是容许的误差下降值,tolN是切分的最少样本数。通过对当前所有目标变量建立一个集合,然后统计不同剩余特征值的数目,如果该数目为1,那么不需要再切分直接返回,然后函数计算了当前数据集的大小和误差,该误差S将用于和新切分误差进行对比,检查新切分能否降低误差。 如果切分数据集后效果提升不够大,那么就不进行切分操作而直接创建叶节点。

另外,还需要检验两个切分后的子集大小,如果某个子集大小小于用户定义的参数tolN,那么也不进行切分。

运行结果:

树回归

后剪枝

伪代码:

基于已有的树切分测试数据:

如果存在任一子集是一棵树,则在该子集递归剪枝过程

计算将当前两个节点合并后的误差

计算不合并的误差

如果合并会降低误差的话,就将叶节点合并。

coding:

def prune(tree, testData):

if shape(testData)[0] == 0: return getMean(tree)

if (isTree(tree['right']) or isTree(tree['left'])):

lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])

if isTree(tree['left']): tree['left'] = prune(tree['left'],lSet)

if isTree(tree['right']): tree['right'] = prune(tree['right'],rSet)

if not isTree(tree['left']) and not isTree(tree['right']):

lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])

errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) + /

sum(power(rSet[:,-1] - tree['right'],2))

treeMean = (tree['left'] + tree['right']) / 2.0

errorMerge = sum(power(testData[:,-1] - treeMean,2))

if errorMerge < errorNoMerge:

print "merging"

return treeMean

else : return tree

else:return tree

函数prune()有两个参数,待剪枝的树与剪枝所需的测试集。首先确认测试集是否为空,如果非空,则反复递归调用prune()对测试数据进行切分。检查某个分支到底是子树还是节点。如果是子树,就调用函数来对子树进行剪枝。再对左右两个分支完成剪枝后,还需要检查它们是否仍然还是子树,如果已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较,如果合并后的误差比不合并的小就进行合并操作,反之不合并直接返回。

运行结果:

树回归

模型树

用树来对数据进行建模,除了把叶节点设定为常数值外,还可以将其设定为分段线性函数,分段线性(piecewise linear)即模型由多个线性片段组成。

树回归

可以设计两条分别从0.0-0.3、从0.3~1.0的直线,得到两个线性模型,即分段线性模型。

两条直线比很多节点组成一颗大树更容易理解。模型树的可解释性是它优于回归树的特点之一。模型树也具有更高的预测准确度。利用树生成算法对数据进行切分,且每份切分数据都能很容易被线性模型所表示,关键在于找到最佳切分。

def linearSolve(dataSet):

m, n = shape(dataSet)

X = mat(ones((m,n)));

Y = mat(ones((m,1)))

X[:,1:n] = dataSet[:,0:n-1];

Y = dataSet[:,-1]

XTX = X.T * X

if linalg.det(XTX) == 0.0:

raise NameError('This matrix is singular, cannot do inverse, /n/

try increasing the second value of ops')

ws = XTX.I * (X.T * Y)

return ws, X, Y

def modelLeaf(dataSet):

ws, X,Y = linearSolve(dataSet)

return ws

def modelErr(dataSet):

ws, X,Y = linearSolve(dataSet)

yHat = X * ws

return sum(power(Y - yHat, 2))

运行结果:

树回归

可以看到,该代码以0.285477为界创建了两个模型,而原图中的数据实际在0.3处分段,createTree()生成的这两个线性模型分别为:y = 3.468 + 1.1852x和0.0016985 + 11.96477x,与用于生成该睡的真是模型非常接近。

该数据实际是由模型y = 3.5 + 1.0x 和 y = 0 + 12x再加上高斯噪声生成的。

完整代码地址: https://github.com/JLUNeverMore/Tree_regression

未登录用户
全部评论0
到底啦