辛煌炜的其它学习记录
-
LeetCode 动态规划
一、 乘积最大子数组
1. 题目
给你一个整数数组
nums
,请你找出数组中乘积最大的连续子数组(该子数组中至少包含一个数字),并返回该子数组所对应的乘积。2. 输入输出
输入: [2,3,-2,4] 输出: 6 解释: 子数组 [2,3] 有最大乘积 6。
输入: [-2,0,-1] 输出: 0 解释: 结果不能为 2, 因为 [-2,-1] 不是子数组。
3. 解析
建立两个数组分别用来记录从0到k中的最大乘积和最小乘积(0 <= k < nums.length)
当前索引k下的最大乘积可能有三种取值方式
一种是当前索引k下的数组值 nums[k]
一种是当前数组值乘以索引 k-1 的最大乘积(一般两者都为正数)
一种是当前数组值乘以索引 k-1 的最小乘积(例如两者都为负数,乘积反而为最大值)
故每次都需要记录索引 k 下的最大和最小乘积
4. 解答
class Solution { public int maxProduct(int[] nums) { int max = nums[0] ; int[] data_max = new int[nums.length] ; int[] data_min = new int[nums.length] ; data_max[0] = nums[0] ; data_min[0] = nums[0] ; for(int i = 1 ; i < nums.length ; i++){ data_max[i] = maxInThree(nums[i], nums[i] * data_max[i-1], nums[i] * data_min[i-1]) ; data_min[i] = minInThree(nums[i], nums[i] * data_max[i-1], nums[i] * data_min[i-1]) ; max = maxInThree(max, data_max[i], data_min[i]) ; } return max ; } public int maxInThree(int a, int b, int c){ int m = a > b ? a : b ; int n = m > c ? m : c ; return n ; } public int minInThree(int a, int b, int c){ int m = a < b ? a : b ; int n = m < c ? m : c ; return n ; } }
-
二、Best Time to Buy and Sell Stock with Cooldown
1. 题目
给定一个整数数组,其中第 i 个元素代表了第 i 天的股票价格 。
设计一个算法计算出最大利润。在满足以下约束条件下,你可以尽可能地完成更多的交易(多次买卖一支股票):
- 你不能同时参与多笔交易(你必须在再次购买前出售掉之前的股票)。
- 卖出股票后,你无法在第二天买入股票 (即冷冻期为 1 天)。
2. 输入输出
输入: [1,2,3,0,2] 输出: 3 解释: 对应的交易状态为: [买入, 卖出, 冷冻期, 买入, 卖出]
3. 解析
对于每一天一只股票可以有三种状态:买入、卖出、冻结期
所以我们可以建立三个数组来表示每一天假如处于每个状态的最大收益是多少
假设数组b表示buy买入,数组s表示sell卖出,数组r表示reset冻结,数组prices表示价格
我们可以得到三者之间关系式
-
b[i] = r[i-1] 第i天买入则前一天必为冻结
-
s[i] = max{ s[i-1] + prices[i] - prices[i-1],b[i-1] + prices[i] - prices[i-1]}
第i天卖出则前一天可能为买入也可能为卖出
-
r[i] = max{ s[i-1], r[i-1]} 第i天冻结则前一天可能为买入也可能为卖出
4. 解答
public int maxProfit(int[] prices) { if(prices.length == 0){ return 0 ; } int[] b = new int[prices.length] ; int[] s = new int[prices.length] ; int[] r = new int[prices.length] ; b[0] = 0 ; s[0] = 0 ; r[0] = 0 ; for(int i = 1 ; i < prices.length ; i++){ b[i] = r[i-1] ; s[i] = Math.max(s[i-1] + prices[i] - prices[i-1], b[i-1] + prices[i] - prices[i-1]) ; r[i] = Math.max(s[i-1], r[i-1]) ; } //此时最后一天的状态一定在三种数组的状态之中,所以我们直接比较最后一天的数组结果最大值即可 return Math.max(b[prices.length-1], Math.max(s[prices.length-1], r[prices.length-1])) ; }
假如把三个表达式消去数组r也可以得到两个表达式的计算结果
- s[i] = max{ s[i-1] + prices[i] - prices[i-1],b[i-1] + prices[i] - prices[i-1]}
- b[i] = max{ s[i-2],b[i-1] }
public int maxProfit(int[] prices) { if(prices.length == 0){ return 0 ; } int[] b = new int[prices.length] ; int[] s = new int[prices.length] ; b[0] = 0 ; s[0] = 0 ; int max = 0 ; for(int i = 1 ; i < prices.length ; i++){ s[i] = Math.max(s[i-1] + prices[i] - prices[i-1], b[i-1] + prices[i] - prices[i-1]) ; b[i] = (i == 1)? 0 : Math.max(s[i-2], b[i-1]) ; max = Math.max(max, Math.max(s[i], b[i])) ; } //由于去掉了冻结期的数组状态,当最后一天是冻结期时无法直接表示 //所以我们通过边计算数组,边比较最大值的方式得到结果 return max ; }
-
三、Perfect Squares
1. 题目
给定正整数 n,找到若干个完全平方数(比如
1, 4, 9, 16, ...
)使得它们的和等于 n。你需要让组成和的完全平方数的个数最少。2. 输入输出
输入: n = 12 输出: 3 解释: 12 = 4 + 4 + 4.
输入: n = 13 输出: 2 解释: 13 = 4 + 9.
3. 解析
维护一个数组用来表示所有数字的最少平方数组合次数
循环每一个数字从1、2、3、...的平方开始拆解
并且与已经计算出的数字的组合次数相加得出新的组合结果
4. 解答
public int numSquares(int n) { int[] data = new int[n+1] ; data[0] = 0 ; for(int i = 1 ; i < n+1 ; i++){ data[i] = Integer.MAX_VALUE ; for(int j = 1 ; j*j <= i ; j++){ int min = 1 + data[i - j*j] ; if(min < data[i]){ data[i] = min ; } } } return data[n] ; }
-
四、单词拆分
1. 题目
给定一个非空字符串 s 和一个包含非空单词列表的字典 wordDict,判定 s 是否可以被空格拆分为一个或多个在字典中出现的单词。
说明:
- 拆分时可以重复使用字典中的单词。
- 你可以假设字典中没有重复的单词。
2. 输入输出
示例 1:
输入: s = "leetcode", wordDict = ["leet", "code"] 输出: true 解释: 返回 true 因为 "leetcode" 可以被拆分成 "leet code"。
示例 2:
输入: s = "applepenapple", wordDict = ["apple", "pen"] 输出: true 解释: 返回 true 因为 "applepenapple" 可以被拆分成 "apple pen apple"。 注意你可以重复使用字典中的单词。
3. 解析
维护一个数组用来表示截取到第n位的字符串时,这个字符串是否能由字典中的单词组成
对于每一个字符串如果能被拆分一定能表示为
- 一个已经确定能被拆分的小字符串
- 一个字典中的单词
这两者组成
4. 解答
public boolean wordBreak(String s, List<String> wordDict) { boolean[] data = new boolean[s.length()] ; if(wordDict.contains(s.substring(0, 1))){ data[0] = true ; } for(int i = 1 ; i < s.length() ; i++){ for(int j = i ; j >= 0 ; j--){ if((j-1 < 0 || data[j-1]) && wordDict.contains(s.substring(j, i+1))){ data[i] = true ; } } } return data[s.length()-1] ; }
-
五、单词拆分 II
1. 题目
给定一个非空字符串 s 和一个包含非空单词列表的字典 wordDict,在字符串中增加空格来构建一个句子,使得句子中所有的单词都在词典中。返回所有这些可能的句子。
说明:
- 分隔时可以重复使用字典中的单词。
- 你可以假设字典中没有重复的单词。
2. 输入输出
示例 1:
输入: s = "catsanddog" wordDict = ["cat", "cats", "and", "sand", "dog"] 输出: [ "cats and dog", "cat sand dog" ]
示例 2:
输入: s = "pineapplepenapple" wordDict = ["apple", "pen", "applepen", "pine", "pineapple"] 输出: [ "pine apple pen apple", "pineapple pen apple", "pine applepen apple" ] 解释: 注意你可以重复使用字典中的单词。
3. 解析
这一题其实与上一题的题干基本相同,只不过需要我们表示出所有的拆分情况
我们知道一个字符串如果能够被拆分它一定是以一个字典中的单词为开头
所以我们遍历字典查找字符串的开头
剩余的部分继续拆解然后与开头进行自由组合
我们同时维护一个已经拆解的字符串表map
避免重复拆解浪费时间
4. 解答
private Map<String, List<String>> map = new HashMap<>(); public List<String> wordBreak(String s, List<String> wordDict) { if (map.containsKey(s)) //如果包含 则直接返回s return map.get(s); List<String> list = new ArrayList<>(); if (s.length() == 0) { list.add(""); return list; } for (String word : wordDict) { if (s.startsWith(word)) {//判断s是否含有word的前缀 List<String> tmpList = wordBreak(s.substring(word.length()), wordDict); for (String tmp : tmpList) list.add(word + (tmp.equals("") ? "" : " ") + tmp);//空的话则""结尾 } } map.put(s, list);//记录可以拆分的字符串,并且记录拆分的方法 return list; }
-
六、戳气球
1. 题目
有
n
个气球,编号为0
到n-1
,每个气球上都标有一个数字,这些数字存在数组nums
中。现在要求你戳破所有的气球。每当你戳破一个气球
i
时,你可以获得nums[left] * nums[i] * nums[right]
个硬币。 这里的left
和right
代表和i
相邻的两个气球的序号。注意当你戳破了气球i
后,气球left
和气球right
就变成了相邻的气球。求所能获得硬币的最大数量。
说明:
- 你可以假设
nums[-1] = nums[n] = 1
,但注意它们不是真实存在的所以并不能被戳破。 - 0 ≤
n
≤ 500, 0 ≤nums[i]
≤ 100
2. 输入输出
示例:
输入: [3,1,5,8] 输出: 167 解释: nums = [3,1,5,8] --> [3,5,8] --> [3,8] --> [8] --> [] coins = 3*1*5 + 3*5*8 + 1*3*8 + 1*8*1 = 167
3. 解析
官方解答 https://leetcode-cn.com/problems/burst-balloons/solution/chuo-qi-qiu-by-leetcode/
当我们正向思考时会发现,我们每一次戳破气球都会改变气球之间的相邻情况
从而影响到戳破下一个气球的硬币价值计算
所以我们不如反向思考一下,将题目要求的戳破气球改为放置气球
从一开始没有气球,到之后每放置一个气球我们就计算戳破这个气球的硬币价值
由于气球是不断添加的,我们添加一个气球并不影响上一个添加气球的硬币计算情况
因为在上一个添加气球看来我们后添加的气球是先被戳破的
戳破之后上一个气球所处的环境就和它被放置时的环境一模一样
我们可以得到放置气球的计算公式为
dp【i】【j】 = max { dp【i】【j】,dp【i】【k】 + dp【k】【j】 + coins【i】* coins【k】* coins【j】}
表示最优硬币价值为两侧子问题的最优解加上戳破当前气球得到的硬币价值
4. 解答
我们首先根据添加气球的思路可以得到一个递归的解法
每次添加一个气球之后,该气球的两侧就被分割成为两个子问题
当递归到没有气球只有边界时,返回硬币的价值为0
public int maxCoins(int[] nums){ int n = nums.length ; //扩展数组并填充边界两侧为1 //虚拟边界不能被戳破,只是便于我们计算硬币价值 int length = n + 2 ; int[] coins = new int[length] ; coins[0] = coins[length-1] = 1 ; for(int i = 0 ; i < n ; i++){ coins[i+1] = nums[i] ; } //此处可以维护一个记忆数组进行优化 //用于记录我们已经递归得到的子问题的最优解法 //避免重复计算浪费时间 int[][] memo = new int[length][length] ; return calCoins(coins, 0, length-1, memo) ; } public int calCoins(int[] nums, int m, int n, int[][] memo){ if(n - m == 1){ return 0 ; } if(memo[m][n] != 0){ return memo[m][n] ; } int max = 0 ; for(int k = m+1 ; k < n ; k++){ int tmp = calCoins(nums, m, k, memo) + calCoins(nums, k, n, memo) + nums[m]*nums[n]*nums[k] ; if(tmp > max){ max = tmp ; } } memo[m][n] = max ; return max ; }
以上的解法复杂度如下
时间复杂度:O(N^3)
空间复杂度:O(N^2)
我们还可以采用自底向上的方式来填充memo数组
这样的解法我觉得不是很直观,但是好处就是不需要使用递归
只需要通过两层循环就可以实现
public int maxCoins(int[] nums){ int n = nums.length ; //填充数组,扩展边界 int length = n + 2 ; int[] coins = new int[length] ; coins[0] = coins[length-1] = 1 ; for(int i = 0 ; i < n ; i++){ coins[i+1] = nums[i] ; } int[][] dp = new int[length][length] ; //begin //左边界,从右向左递减,最大最小取值包括虚拟边界 for(int i = length-1 ; i >= 0 ; i--){ //end //右边界,从左边界向右递增,最大最小取值包括虚拟边界 for(int j = i+1 ; j < length ; j++){ //k //戳破气球选择(或者理解为放置气球选择) //从左边界到右边界之间,不包括边界,递增,最大最小取值不包括虚拟边界 for(int k = i+1 ; k < j ; k++){ dp[i][j] = Math.max(dp[i][j], dp[i][k] + dp[k][j] + coins[i]*coins[j]*coins[k]) ; } } } return dp[0][length-1] ; }
本题我认为是动态规划系列中最难理解的一题
主要难点在于将戳破气球转变为放置气球的认识
在我们常见的动态规划题目中,上一个动作和下一个动作之间通常可以列出明确的关系式
所以我们可以先计算当前动作下的最优解,然后将其应用到下一个动作的解法中
但是戳气球如果从正向理解的话很显然不符合这个条件
无论怎么戳气球都有可能改变下一动作的计算方式
而反向理解的话它就是一道常规的动态规划题目
不过二维数组的填充顺序也是有一点难度的
总而言之动态规划也是一种利用子问题来帮助快速计算的算法方式
重点就是找出子问题和下一个动作之间的动态表达式
这么一想动态规划和递归真的好像XD
- 只不过递归是自顶向下划分子问题
- 动态规划是自底向上的计算方法(但是动规的方式更有利于子问题结果的重复利用
- 你可以假设
-
Leetcode 滑动窗口算法
一、最小覆盖子串
1. 题目
给你一个字符串 S、一个字符串 T,请在字符串 S 里面找出:包含 T 所有字符的最小子串。
2. 输入输出
示例:
输入: S = "ADOBECODEBANC", T = "ABC" 输出: "BANC"
3. 解析
对于寻找子串的问题我们都可以使用滑动窗口算法来解决
滑动窗口算法的本质就是双指针,左指针和右指针之间的字符串就是我们的子串
算法首先初始化左指针和右指针全都指向最左侧
- 然后向右移动右指针直到找到满足条件的子串
- 接着向右移动左指针直到子串不在满足要求
- 重复以上操作直到指针移动到边界或者不再满足移动的条件为止
4. 解答
class Solution { public String minWindow(String s, String t) { int sLen = s.length() ; int tLen = t.length() ; if(sLen == 0 || tLen == 0 || sLen < tLen){ return "" ; } //首先初始化两个数组用于记录字符串字符出现次数 //由于最大的字符Z的ASCII码不会超过128 //故我们统一开辟数组大小为128 //PS:通用的方法是使用HashMap,数组方便的话也可以使用数组替代 int[] winFreq = new int[128] ; int[] tFreq = new int[128] ; //把字符串转化为字符数组方便读取字符 char[] charArrayS = s.toCharArray() ; char[] charArrayT = t.toCharArray() ; //统计子串的各字符出现次数 for(char c : charArrayT){ tFreq[c]++ ; } //distance用于记录总的满足条件字符数 int distance = 0 ; int left = 0 ; int right = 0 ; int minLen = sLen + 1 ; int begin = 0 ; while(right < sLen){ char rightChar = charArrayS[right] ; // 无关的字符直接向右移动指针跳过 if(tFreq[rightChar] == 0){ right++ ; continue ; } // 如果需要的字符出现次数还没有满足条件 // 则继续收集字符 if(winFreq[rightChar] < tFreq[rightChar]){ distance++ ; } winFreq[rightChar]++ ; right++ ; // 当子串满足条件时 while(distance == tLen){ // 比较最小子串 if(right - left < minLen){ minLen = right - left ; begin = left ; } int charLeft = charArrayS[left] ; // 无关的字符直接向右移动指针跳过 if(tFreq[charLeft] == 0){ left++ ; continue ; } // 如果左侧的字符去掉之后会影响所需要的子串字符个数 // 使得子串不再满足条件 if(winFreq[charLeft] == tFreq[charLeft]){ distance-- ; } winFreq[charLeft]-- ; left++ ; } } // 没有找到符合条件的子串 if(minLen == sLen + 1){ return "" ; } // 截取最小子串 return s.substring(begin, begin + minLen) ; } }
以上的解法我们是以加法的思想
即收集我们所需要的子串字符直到满足条件
这种方法需要多开辟一个数组来统计当前已经收集的子串字符个数来进行比较
我们可以转换一个减法的思路来降低算法的空间复杂度
即当遇到需要的子串字符时我们减少所需要的子串字符上限
当所需要的字符数降为0时即获得了满足要求的子串
public String minWindow(String s, String t) { int sLen = s.length() ; int tLen = t.length() ; if(sLen == 0 || tLen == 0 || sLen < tLen){ return "" ; } // int[] winFreq = new int[128] ; int[] tFreq = new int[128] ; char[] charArrayS = s.toCharArray() ; char[] charArrayT = t.toCharArray() ; for(char c : charArrayT){ tFreq[c]++ ; } // 初始需要的总字符个数为字符串的长度 int distance = tLen ; int left = 0 ; int right = 0 ; int minLen = sLen + 1 ; int begin = 0 ; while(right < sLen){ char rightChar = charArrayS[right] ; // 对于无关的字符实际上可以不做处理,并不会影响最后的结果 // 读者可以仔细分析 // if(tFreq[rightChar] == 0){ // right++ ; // continue ; // } // 如果当前字符是我们所需要的字符 // 则我们还需要收集的字符数减少 if(tFreq[rightChar] > 0){ distance-- ; } // 当前字符需要收集的个数也减少 tFreq[rightChar]-- ; right++ ; // 当需要收集的字符数为0表示找到了子串 while(distance == 0){ // 记录最小子串 if(right - left < minLen){ minLen = right - left ; begin = left ; } int charLeft = charArrayS[left] ; // 对于无关的字符实际上可以不做处理,并不会影响最后的结果 // 读者可以仔细分析 // if(tFreq[charLeft] == 0){ // left++ ; // continue ; // } // 把左侧的字符剔除出窗口时 // 恢复所需要的对该字符的一次收集次数 // PS:我们可以观察当这个字符是无关字符时,由于我们之前没有对其进行特殊处理 // 它的引用次数会被减为-1,只有这个字符是需要的才会在前面的操作中被减为0 if(tFreq[charLeft] == 0){ distance++ ; } tFreq[charLeft]++ ; left++ ; } } if(minLen == sLen + 1){ return "" ; } return s.substring(begin, begin + minLen) ; }
-
LeetCode 数组和字符串
一、缺失的第一个正数
1. 题目
给你一个未排序的整数数组,请你找出其中没有出现的最小的正整数。
2. 输入输出
示例 1:
输入: [1,2,0] 输出: 3
示例 2:
输入: [3,4,-1,1] 输出: 2
示例 3:
输入: [7,8,9,11,12] 输出: 1
3. 解析
官方解答 https://leetcode-cn.com/problems/first-missing-positive/solution/que-shi-de-di-yi-ge-zheng-shu-by-leetcode-solution/
首先我们最直接的思路就是从1开始试验最小的未出现正整数
- 如果数组中包含这个数那么就递增试验
- 如果数组中没有包含这个数那么就直接返回
在这种思路下我们可以选择ArrayList来存储数字
那么直接使用contains方法就可以判断数字是否存在
但是这种方法的空间复杂度太高
改进之后我们也可以在原数组的基础上进行操作
我们仔细分析之后可以发现这个最小未出现正整数只与不大于数组大小的数字有关
- 假设数组大小为N,试想一下如果数组中的数字就是1~N,则输出结果为N+1
- 如果数组中包含大于N的数字,那么输出结果一定在1~N中
所以我们只需要记录不大于N的数字的出现信息就可以得到结果
此时利用原数组直接把数字与数组下标对应就可以实现原地记录
我们可以有以下方法
- 首先遍历将数组中所有负数置为N+1,表示无关项
- 然后遍历将数组中出现的1~N的数字,把数字对应数组下标中的数字置为负数
表示该下标对应的数字出现过
- 最后遍历数组,如果数字为正数表示该处下标的数字未出现过
另外还有一种速度更快,且只使用极小额外空间的解法
即使用BitSet来记录数字是否出现的解法
因为只需要记录的数字大小只与数组的大小有关
所以我们可以根据数组的大小来开辟空间
记录数字之后只需要从1开始遍历验证就可以了
4. 解答
原地使用数组的解法
int n = nums.length ; for(int i = 0 ; i < n ; i++){ if(nums[i] <= 0){ nums[i] = n+1 ; } } for(int i = 0 ; i < n ; i++){ int real = Math.abs(nums[i]) ; if(real >= 1 && real <= n){ int index = real - 1 ; nums[index] = (-1) * Math.abs(nums[index]) ; } System.out.println(nums[i]) ; } for(int i = 0 ; i < n ; i++){ if(nums[i] > 0){ return i+1 ; } } return n+1 ;
使用BitSet的解法
int n = nums.length ; //能容纳N位数字 int[] bitSet = new int[(n - 1) / 32 + 1] ; for(int num : nums){ if(num >= 1 && num <= n){ //记录1~N的数字是否出现 int index = (num - 1) / 32 ; //出现数字则将某一位置为1 bitSet[index] |= (1 << ((num - 1) % 32)) ; } } for(int i = 0 ; i < n ; i++){ //如果与的结果为0说明对应记录位置没有被置为1,该数字没有出现过 if((bitSet[i / 32] & (1 << (i % 32))) == 0){ return i + 1 ; } } return n + 1 ;
-
KNN算法
1. 简单概括
通过计算数据集的数据和输入数据之间的欧氏距离
将结果按从小到大的顺序排序
选出前k个欧式距离最小的数据集数据
这k个数据中出现次数最多的标签的即为估计结果
2. 核心代码
def classify(inX, dataSet, labels, k): # inx 输入数据 # dataSet 数据集 # labels 标签 # k 选取前k个数据 # dataSetSize 数据集数据数量 dataSetSize = dataSet.shape[0] # np.tile 用于复制扩展inX使之大小与dataSet相同以便于后面同时对整个数据集数据进行计算操作 diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet # 作差后取平方 sqDiffMat = diffMat**2 # 把同一行的数据相加 sumMat = sqDiffMat.sum(1) # 对所有数据开方 distance = sumMat**0.5 # 返回当前数组索引在数组排序之后的顺序 # [3, 6, 4, 2] 索引 [0, 1, 2, 3] # 返回结果 [3, 0, 2, 1] sortedDistance = distance.argsort() classCount = {} for i in range(k): # 前k个标签 label = labels[sortedDistance[i]] # dict.get(label, 0) 如果字典中没有当前索引会用0作为初始值,不写0默认为None classCount[label] = classCount.get(label, 0) + 1 # 对字典进行排序,排序依据key是字典中的第二个维度,lambda x: x[1] # 也可以使用 operator.itemgetter(1) 也表示第二个维度 result = sorted(classCount.items(), key=lambda x: x[1], reverse=True) # result = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return result[0][0]
-
决策树算法
1. 简单概括
利用熵将整个数据集进行分割
使得分割后数据集的熵最小
不断对子数据集进行递归
直至无法进一步分割或者子数据集里数据的标签都一致时递归结束
分割的过程会形成一棵决策树
利用决策树将输入的数据归类到某一分割后的数据集中
子数据集所带有的标签就是决策的结果
2. 熵的定义和实现
熵定义为信息的期望值
-
信息的计算公式为
$$
l(x_i) = -log_2p(x_i) \quad 其中 \ p(x_i) \ 是选择某一分类的概率
$$ -
信息的期望值为
-
核心代码
def calcShannon(dataSet): dataSetSize = len(dataSet) result = {} # 利用字典统计每一种标签的数据数量 for data in dataSet: label = data[-1] result[label] = result.get(label, 0) + 1 shannon = 0 # 对于每一种分类计算其概率并统计信息熵 for key in result: prob = result[key]/dataSetSize shannon -= prob * math.log(prob, 2) return shannon
3. 分割数据集
# 分割数据集,将维度axis且值为value的数据单独提取出来 def splitDataSet(dataSet, axis, value): result = [] for data in dataSet: if data[axis] == value: # 使用axis分割后将axis这一列从数据中去掉 tmp = data[:axis] # 这里通过entend拼接数组跳过了axis这一列 tmp.extend(data[axis+1:]) result.append(tmp) return result
4. 选择最好的分割维度
def chooseBestSplitAxis(dataSet): numOfAxis = len(dataSet[0]) - 1 baseEntropy = calcShannon(dataSet) bestInfoGain = 0.0 bestAxis = -1 # 对所有维度都循环试验 for axis in range(numOfAxis): allValue = [x[axis] for x in dataSet] # 维度下所有可能的不重复值 valueSet = set(allValue) newEntropy = 0.0 for value in valueSet: # 按不重复的值分割成子数据集 subDataSet = splitDataSet(dataSet, axis, value) prob = float(len(subDataSet)) / float(len(dataSet)) # 新的熵为部分熵按比例求和 newEntropy += prob * calcShannon(subDataSet) # 作差为正说明新熵比旧熵小,混乱程度减小 # newInfoGain = baseEntropy - newEntropy # if newInfoGain > bestInfoGain: # bestAxis = axis # bestInfoGain = newInfoGain if newEntropy < baseEntropy: # 上面源码有点绕 简单理解就是熵变小了就选择 bestAxis = axis return bestAxis
5. 生成决策树
def createTree(dataSet, labels): labelList = [data[-1] for data in dataSet] # 递归出口一:子数据集的标签已经统一只有一种,不需要再进一步分割 if labelList.count(labelList[0]) == len(labelList): return labelList[0] # 递归出口二:子数据集已经没有了可分割的维度只剩下了标签 if len(dataSet[0]) == 1: # 统计子数据集中出现次数最多的标签即为决策结果 return voteMaxLabel(labelList) # 选择熵最小的分割维度 bestAxis = chooseBestSplitAxis(dataSet) bestLabel = labels[bestAxis] # 建立决策树字典 myTree = {bestLabel: {}} # 删除已用于分割的维度对应的标签 del(labels[bestAxis]) allValue = [x[bestAxis] for x in dataSet] valueSet = set(allValue) for value in valueSet: # 复制一遍标签 subLabel = labels[:] # 采用最好的分割方法分割数据集后递归生成子树 myTree[bestLabel][value] = createTree(splitDataSet(dataSet, bestAxis, value), subLabel) return myTree
6. 利用决策树进行决策
-
决策树生成范例
{'flippers': {0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}}}
其中每一次决策需要用到决策树的两层
以这里的决策树为例
第一层‘flippers’为进行决策的标签
第二层的0和1为在该标签下进行决策的不同选择
-
核心代码
def classifyByTree(tree, labels, data): # 得到进行决策的标签 firstLabel = list(tree.keys())[0] # 用该标签进行决策的子树 secondDict = tree[firstLabel] # 得到用于决策的标签所属的维度,用于后面取出数据在该维度的值 firstLabelIndex = labels.index(firstLabel) classLabel = 'Error' # 对于该标签下进行决策的不同的值 for value in secondDict.keys(): # 如果数据在该决策标签维度下的值等于子树的决策值 if data[firstLabelIndex] == value: # 如果子树下还有子树(即还是一个字典类型的数据)则继续进行决策 if type(secondDict[value]).__name__ == 'dict': classLabel = classifyByTree(secondDict[value], labels, data) else: # 否则子树下的值就是决策的结果 classLabel = secondDict[value] return classLabel
-