楼主: Lisrelchen
1522 10

Milk: Machine Learning Toolkit in Python [推广有奖]

11
Lisrelchen 发表于 2016-4-25 03:41:20
  1. # -*- coding: utf-8 -*-
  2. # Copyright (C) 2010-2011, Luis Pedro Coelho <luis@luispedro.org>
  3. # vim: set ts=4 sts=4 sw=4 expandtab smartindent:
  4. #
  5. # License: MIT. See COPYING.MIT file in the milk distribution

  6. '''
  7. Random Forest
  8. -------------
  9. Main elements
  10. -------------
  11. rf_learner : A learner object
  12. '''

  13. from __future__ import division
  14. import numpy as np
  15. import milk.supervised.tree
  16. from .normalise import normaliselabels
  17. from .base import supervised_model
  18. from ..utils import get_nprandom

  19. __all__ = [
  20.     'rf_learner',
  21.     ]

  22. def _sample(features, labels, n, R):
  23.     '''
  24.     features', labels' = _sample(features, labels, n, R)
  25.     Sample n element from (features,labels)
  26.     Parameters
  27.     ----------
  28.     features : sequence
  29.     labels : sequence
  30.         Same size as labels
  31.     n : integer
  32.     R : random object
  33.     Returns
  34.     -------
  35.     features' : sequence
  36.     labels' : sequence
  37.     '''

  38.     N = len(features)
  39.     sfeatures = []
  40.     slabels = []
  41.     for i in range(n):
  42.         idx = R.randint(N)
  43.         sfeatures.append(features[idx])
  44.         slabels.append(labels[idx])
  45.     return np.array(sfeatures), np.array(slabels)

  46. class rf_model(supervised_model):
  47.     def __init__(self, forest, names, return_label = True):
  48.         self.forest = forest
  49.         self.names = names
  50.         self.return_label = return_label

  51.     def apply(self, features):
  52.         rf = len(self.forest)
  53.         votes = sum(t.apply(features) for t in self.forest)
  54.         if self.return_label:
  55.             return (votes > (rf//2))
  56.         return votes / rf


  57. class rf_learner(object):
  58.     '''
  59.     Random Forest Learner
  60.     learner = rf_learner(rf=101, frac=.7)
  61.     Attributes
  62.     ----------
  63.     rf : integer, optional
  64.         Nr of trees to learn (default: 101)
  65.     frac : float, optional
  66.         Sample fraction
  67.     R : np.random object
  68.         Source of randomness
  69.     '''
  70.     def __init__(self, rf=101, frac=.7, R=None):
  71.         self.rf = rf
  72.         self.frac = frac
  73.         self.R = get_nprandom(R)

  74.     def train(self, features, labels, normalisedlabels=False, names=None, return_label=True, **kwargs):
  75.         N,M = features.shape
  76.         m = int(self.frac*M)
  77.         n = int(self.frac*N)
  78.         R = get_nprandom(kwargs.get('R', self.R))
  79.         tree = milk.supervised.tree.tree_learner(return_label=return_label)
  80.         forest = []
  81.         if not normalisedlabels:
  82.             labels,names = normaliselabels(labels)
  83.         elif names is None:
  84.             names = (0,1)
  85.         for i in range(self.rf):
  86.             forest.append(
  87.                     tree.train(*_sample(features, labels, n, R),
  88.                                **{'normalisedlabels' : True})) # This syntax is necessary for Python 2.5
  89.         return rf_model(forest, names, return_label)
复制代码

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
加好友,备注jltj
拉您入交流群
GMT+8, 2025-12-27 08:15