熱線電話:13121318867

登錄
首頁精彩閱讀CART分類回歸樹算法
CART分類回歸樹算法
2015-12-03
收藏

CART分類回歸樹算法


    CART分類回歸樹算法

與上次文章中提到的ID3算法和C4.5算法類似,CART算法也是一種決策樹分類算法。CART分類回歸樹算法的本質也是對數據進行分類的,最終數據的表現形式也是以樹形的模式展現的,與ID3,C4.5算法不同的是,他的分類標準所采用的算法不同了。下面列出了其中的一些不同之處:

1、CART最后形成的樹是一個二叉樹,每個節點會分成2個節點,左孩子節點和右孩子節點,而在ID3和C4.5中是按照分類屬性的值類型進行劃分,于是這就要求CART算法在所選定的屬性中又要劃分出最佳的屬性劃分值,節點如果選定了劃分屬性名稱還要確定里面按照那個值做一個二元的劃分。

2、CART算法對于屬性的值采用的是基于Gini系數值的方式做比較,gini某個屬性的某次值的劃分的gini指數的值為:

,pk就是分別為正負實例的概率,gini系數越小說明分類純度越高,可以想象成與熵的定義一樣。因此在最后計算的時候我們只取其中值最小的做出劃分。最后做比較的時候用的是gini的增益做比較,要對分類號的數據做出一個帶權重的gini指數的計算。舉一個網上的一個例子:

比如體溫為恒溫時包含哺乳類5個、鳥類2個,則:

體溫為非恒溫時包含爬行類3個、魚類3個、兩棲類2個,則

所以如果按照“體溫為恒溫和非恒溫”進行劃分的話,我們得到GINI的增益(類比信息增益):

最好的劃分就是使得GINI_Gain最小的劃分。

通過比較每個屬性的最小的gini指數值,作為最后的結果。

3、CART算法在把數據進行分類之后,會對樹進行一個剪枝,常用的用前剪枝和后剪枝法,而常見的后剪枝發包括代價復雜度剪枝,悲觀誤差剪枝等等,我寫的此次算法采用的是代價復雜度剪枝法。代價復雜度剪枝的算法公式為:

α表示的是每個非葉子節點的誤差增益率,可以理解為誤差代價,最后選出誤差代價最小的一個節點進行剪枝。

里面變量的意思為:

是子樹中包含的葉子節點個數;

是節點t的誤差代價,如果該節點被剪枝;

r(t)是節點t的誤差率;

p(t)是節點t上的數據占所有數據的比例。

是子樹Tt的誤差代價,如果該節點不被剪枝。它等于子樹Tt上所有葉子節點的誤差代價之和。下面說說我對于這個公式的理解:其實這個公式的本質是對于剪枝前和剪枝后的樣本偏差率做一個差值比較,一個好的分類當然是分類后的樣本偏差率相較于沒分類(就是剪枝掉的時候)的偏差率小,所以這時的值就會大,如果分類前后基本變化不大,則意味著分類不起什么效果,α值的分子位置就小,所以誤差代價就小,可以被剪枝。但是一般分類后的偏差率會小于分類前的,因為偏差數在高層節點的時候肯定比子節點的多,子節點偏差數最多與父親節點一樣。

CART算法實現

首先是程序的備用數據,我是把他存在了一個文字中,通過程序進行逐行的讀?。?/span>

[java] view plaincopyprint?
  1. Rid Age Income Student CreditRating BuysComputer  
  2. 1 Youth High No Fair No  
  3. 2 Youth High No Excellent No  
  4. 3 MiddleAged High No Fair Yes  
  5. 4 Senior Medium No Fair Yes  
  6. 5 Senior Low Yes Fair Yes  
  7. 6 Senior Low Yes Excellent No  
  8. 7 MiddleAged Low Yes Excellent Yes  
  9. 8 Youth Medium No Fair No  
  10. 9 Youth Low Yes Fair Yes  
  11. 10 Senior Medium Yes Fair Yes  
  12. 11 Youth Medium Yes Excellent Yes  
  13. 12 MiddleAged Medium No Excellent Yes  
  14. 13 MiddleAged High Yes Fair Yes  
  15. 14 Senior Medium No Excellent No  

下面是主程序,里面有具體的注釋:

[java] view plaincopyprint?
  1. package DataMing_CART;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.HashMap;  
  9. import java.util.LinkedList;  
  10. import java.util.Map;  
  11. import java.util.Queue;  
  12.   
  13. import javax.lang.model.element.NestingKind;  
  14. import javax.swing.text.DefaultEditorKit.CutAction;  
  15. import javax.swing.text.html.MinimalHTMLWriter;  
  16.   
  17. /** 
  18.  * CART分類回歸樹算法工具類 
  19.  *  
  20.  * @author lyq 
  21.  *  
  22.  */  
  23. public class CARTTool {  
  24.     // 類標號的值類型  
  25.     private final String YES = "Yes";  
  26.     private final String NO = "No";  
  27.   
  28.     // 所有屬性的類型總數,在這里就是data源數據的列數  
  29.     private int attrNum;  
  30.     private String filePath;  
  31.     // 初始源數據,用一個二維字符數組存放模仿表格數據  
  32.     private String[][] data;  
  33.     // 數據的屬性行的名字  
  34.     private String[] attrNames;  
  35.     // 每個屬性的值所有類型  
  36.     private HashMap<String, ArrayList<String>> attrValue;  
  37.   
  38.     public CARTTool(String filePath) {  
  39.         this.filePath = filePath;  
  40.         attrValue = new HashMap<>();  
  41.     }  
  42.   
  43.     /** 
  44.      * 從文件中讀取數據 
  45.      */  
  46.     public void readDataFile() {  
  47.         File file = new File(filePath);  
  48.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  49.   
  50.         try {  
  51.             BufferedReader in = new BufferedReader(new FileReader(file));  
  52.             String str;  
  53.             String[] tempArray;  
  54.             while ((str = in.readLine()) != null) {  
  55.                 tempArray = str.split(" ");  
  56.                 dataArray.add(tempArray);  
  57.             }  
  58.             in.close();  
  59.         } catch (IOException e) {  
  60.             e.getStackTrace();  
  61.         }  
  62.   
  63.         data = new String[dataArray.size()][];  
  64.         dataArray.toArray(data);  
  65.         attrNum = data[0].length;  
  66.         attrNames = data[0];  
  67.   
  68.         /* 
  69.          * for (int i = 0; i < data.length; i++) { for (int j = 0; j < 
  70.          * data[0].length; j++) { System.out.print(" " + data[i][j]); } 
  71.          * System.out.print("\n"); } 
  72.          */  
  73.   
  74.     }  
  75.   
  76.     /** 
  77.      * 首先初始化每種屬性的值的所有類型,用于后面的子類熵的計算時用 
  78.      */  
  79.     public void initAttrValue() {  
  80.         ArrayList<String> tempValues;  
  81.   
  82.         // 按照列的方式,從左往右找  
  83.         for (int j = 1; j < attrNum; j++) {  
  84.             // 從一列中的上往下開始尋找值  
  85.             tempValues = new ArrayList<>();  
  86.             for (int i = 1; i < data.length; i++) {  
  87.                 if (!tempValues.contains(data[i][j])) {  
  88.                     // 如果這個屬性的值沒有添加過,則添加  
  89.                     tempValues.add(data[i][j]);  
  90.                 }  
  91.             }  
  92.   
  93.             // 一列屬性的值已經遍歷完畢,復制到map屬性表中  
  94.             attrValue.put(data[0][j], tempValues);  
  95.         }  
  96.   
  97.         /* 
  98.          * for (Map.Entry entry : attrValue.entrySet()) { 
  99.          * System.out.println("key:value " + entry.getKey() + ":" + 
  100.          * entry.getValue()); } 
  101.          */  
  102.     }  
  103.   
  104.     /** 
  105.      * 計算機基尼指數 
  106.      *  
  107.      * @param remainData 
  108.      *            剩余數據 
  109.      * @param attrName 
  110.      *            屬性名稱 
  111.      * @param value 
  112.      *            屬性值 
  113.      * @param beLongValue 
  114.      *            分類是否屬于此屬性值 
  115.      * @return 
  116.      */  
  117.     public double computeGini(String[][] remainData, String attrName,  
  118.             String value, boolean beLongValue) {  
  119.         // 實例總數  
  120.         int total = 0;  
  121.         // 正實例數  
  122.         int posNum = 0;  
  123.         // 負實例數  
  124.         int negNum = 0;  
  125.         // 基尼指數  
  126.         double gini = 0;  
  127.   
  128.         // 還是按列從左往右遍歷屬性  
  129.         for (int j = 1; j < attrNames.length; j++) {  
  130.             // 找到了指定的屬性  
  131.             if (attrName.equals(attrNames[j])) {  
  132.                 for (int i = 1; i < remainData.length; i++) {  
  133.                     // 統計正負實例按照屬于和不屬于值類型進行劃分  
  134.                     if ((beLongValue && remainData[i][j].equals(value))  
  135.                             || (!beLongValue && !remainData[i][j].equals(value))) {  
  136.                         if (remainData[i][attrNames.length - 1].equals(YES)) {  
  137.                             // 判斷此行數據是否為正實例  
  138.                             posNum++;  
  139.                         } else {  
  140.                             negNum++;  
  141.                         }  
  142.                     }  
  143.                 }  
  144.             }  
  145.         }  
  146.   
  147.         total = posNum + negNum;  
  148.         double posProbobly = (double) posNum / total;  
  149.         double negProbobly = (double) negNum / total;  
  150.         gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;  
  151.   
  152.         // 返回計算基尼指數  
  153.         return gini;  
  154.     }  
  155.   
  156.     /** 
  157.      * 計算屬性劃分的最小基尼指數,返回最小的屬性值劃分和最小的基尼指數,保存在一個數組中 
  158.      *  
  159.      * @param remainData 
  160.      *            剩余誰 
  161.      * @param attrName 
  162.      *            屬性名稱 
  163.      * @return 
  164.      */  
  165.     public String[] computeAttrGini(String[][] remainData, String attrName) {  
  166.         String[] str = new String[2];  
  167.         // 最終該屬性的劃分類型值  
  168.         String spiltValue = "";  
  169.         // 臨時變量  
  170.         int tempNum = 0;  
  171.         // 保存屬性的值劃分時的最小的基尼指數  
  172.         double minGini = Integer.MAX_VALUE;  
  173.         ArrayList<String> valueTypes = attrValue.get(attrName);  
  174.         // 屬于此屬性值的實例數  
  175.         HashMap<String, Integer> belongNum = new HashMap<>();  
  176.   
  177.         for (String string : valueTypes) {  
  178.             // 重新計數的時候,數字歸0  
  179.             tempNum = 0;  
  180.             // 按列從左往右遍歷屬性  
  181.             for (int j = 1; j < attrNames.length; j++) {  
  182.                 // 找到了指定的屬性  
  183.                 if (attrName.equals(attrNames[j])) {  
  184.                     for (int i = 1; i < remainData.length; i++) {  
  185.                         // 統計正負實例按照屬于和不屬于值類型進行劃分  
  186.                         if (remainData[i][j].equals(string)) {  
  187.                             tempNum++;  
  188.                         }  
  189.                     }  
  190.                 }  
  191.             }  
  192.   
  193.             belongNum.put(string, tempNum);  
  194.         }  
  195.   
  196.         double tempGini = 0;  
  197.         double posProbably = 1.0;  
  198.         double negProbably = 1.0;  
  199.         for (String string : valueTypes) {  
  200.             tempGini = 0;  
  201.   
  202.             posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);  
  203.             negProbably = 1 - posProbably;  
  204.   
  205.             tempGini += posProbably  
  206.                     * computeGini(remainData, attrName, string, true);  
  207.             tempGini += negProbably  
  208.                     * computeGini(remainData, attrName, string, false);  
  209.   
  210.             if (tempGini < minGini) {  
  211.                 minGini = tempGini;  
  212.                 spiltValue = string;  
  213.             }  
  214.         }  
  215.   
  216.         str[0] = spiltValue;  
  217.         str[1] = minGini + "";  
  218.   
  219.         return str;  
  220.     }  
  221.   
  222.     public void buildDecisionTree(AttrNode node, String parentAttrValue,  
  223.             String[][] remainData, ArrayList<String> remainAttr,  
  224.             boolean beLongParentValue) {  
  225.         // 屬性劃分值  
  226.         String valueType = "";  
  227.         // 劃分屬性名稱  
  228.         String spiltAttrName = "";  
  229.         double minGini = Integer.MAX_VALUE;  
  230.         double tempGini = 0;  
  231.         // 基尼指數數組,保存了基尼指數和此基尼指數的劃分屬性值  
  232.         String[] giniArray;  
  233.   
  234.         if (beLongParentValue) {  
  235.             node.setParentAttrValue(parentAttrValue);  
  236.         } else {  
  237.             node.setParentAttrValue("!" + parentAttrValue);  
  238.         }  
  239.   
  240.         if (remainAttr.size() == 0) {  
  241.             if (remainData.length > 1) {  
  242.                 ArrayList<String> indexArray = new ArrayList<>();  
  243.                 for (int i = 1; i < remainData.length; i++) {  
  244.                     indexArray.add(remainData[i][0]);  
  245.                 }  
  246.                 node.setDataIndex(indexArray);  
  247.             }  
  248.             System.out.println("attr remain null");  
  249.             return;  
  250.         }  
  251.   
  252.         for (String str : remainAttr) {  
  253.             giniArray = computeAttrGini(remainData, str);  
  254.             tempGini = Double.parseDouble(giniArray[1]);  
  255.   
  256.             if (tempGini < minGini) {  
  257.                 spiltAttrName = str;  
  258.                 minGini = tempGini;  
  259.                 valueType = giniArray[0];  
  260.             }  
  261.         }  
  262.         // 移除劃分屬性  
  263.         remainAttr.remove(spiltAttrName);  
  264.         node.setAttrName(spiltAttrName);  
  265.   
  266.         // 孩子節點,分類回歸樹中,每次二元劃分,分出2個孩子節點  
  267.         AttrNode[] childNode = new AttrNode[2];  
  268.         String[][] rData;  
  269.   
  270.         boolean[] bArray = new boolean[] { true, false };  
  271.         for (int i = 0; i < bArray.length; i++) {  
  272.             // 二元劃分屬于屬性值的劃分  
  273.             rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);  
  274.   
  275.             boolean sameClass = true;  
  276.             ArrayList<String> indexArray = new ArrayList<>();  
  277.             for (int k = 1; k < rData.length; k++) {  
  278.                 indexArray.add(rData[k][0]);  
  279.                 // 判斷是否為同一類的  
  280.                 if (!rData[k][attrNames.length - 1]  
  281.                         .equals(rData[1][attrNames.length - 1])) {  
  282.                     // 只要有1個不相等,就不是同類型的  
  283.                     sameClass = false;  
  284.                     break;  
  285.                 }  
  286.             }  
  287.   
  288.             childNode[i] = new AttrNode();  
  289.             if (!sameClass) {  
  290.                 // 創建新的對象屬性,對象的同個引用會出錯  
  291.                 ArrayList<String> rAttr = new ArrayList<>();  
  292.                 for (String str : remainAttr) {  
  293.                     rAttr.add(str);  
  294.                 }  
  295.                 buildDecisionTree(childNode[i], valueType, rData, rAttr,  
  296.                         bArray[i]);  
  297.             } else {  
  298.                 String pAtr = (bArray[i] ? valueType : "!" + valueType);  
  299.                 childNode[i].setParentAttrValue(pAtr);  
  300.                 childNode[i].setDataIndex(indexArray);  
  301.             }  
  302.         }  
  303.   
  304.         node.setChildAttrNode(childNode);  
  305.     }  
  306.   
  307.     /** 
  308.      * 屬性劃分完畢,進行數據的移除 
  309.      *  
  310.      * @param srcData 
  311.      *            源數據 
  312.      * @param attrName 
  313.      *            劃分的屬性名稱 
  314.      * @param valueType 
  315.      *            屬性的值類型 
  316.      * @parame beLongValue 分類是否屬于此值類型 
  317.      */  
  318.     private String[][] removeData(String[][] srcData, String attrName,  
  319.             String valueType, boolean beLongValue) {  
  320.         String[][] desDataArray;  
  321.         ArrayList<String[]> desData = new ArrayList<>();  
  322.         // 待刪除數據  
  323.         ArrayList<String[]> selectData = new ArrayList<>();  
  324.         selectData.add(attrNames);  
  325.   
  326.         // 數組數據轉化到列表中,方便移除  
  327.         for (int i = 0; i < srcData.length; i++) {  
  328.             desData.add(srcData[i]);  
  329.         }  
  330.   
  331.         // 還是從左往右一列列的查找  
  332.         for (int j = 1; j < attrNames.length; j++) {  
  333.             if (attrNames[j].equals(attrName)) {  
  334.                 for (int i = 1; i < desData.size(); i++) {  
  335.                     if (desData.get(i)[j].equals(valueType)) {  
  336.                         // 如果匹配這個數據,則移除其他的數據  
  337.                         selectData.add(desData.get(i));  
  338.                     }  
  339.                 }  
  340.             }  
  341.         }  
  342.   
  343.         if (beLongValue) {  
  344.             desDataArray = new String[selectData.size()][];  
  345.             selectData.toArray(desDataArray);  
  346.         } else {  
  347.             // 屬性名稱行不移除  
  348.             selectData.remove(attrNames);  
  349.             // 如果是劃分不屬于此類型的數據時,進行移除  
  350.             desData.removeAll(selectData);  
  351.             desDataArray = new String[desData.size()][];  
  352.      &

數據分析咨詢請掃描二維碼

若不方便掃碼,搜微信號:CDAshujufenxi

數據分析師資訊
更多

OK
客服在線
立即咨詢
日韩人妻系列无码专区视频,先锋高清无码,无码免费视欧非,国精产品一区一区三区无码
客服在線
立即咨詢