楼主: lion003
14887 50

Scala for Machine Learning [推广有奖]

31
Lisrelchen(真实交易用户) 发表于 2016-4-20 09:45:29
  1. package org.scalaml.supervised.hmm
  2. import org.scalaml.core.Matrix
  3. import org.scalaml.core.Types.ScalaMl._
  4. import HMMConfig._


  5. final protected class Alpha(lambda: HMMLambda, obs: Array[Int]) extends Pass(lambda, obs) {

  6.         val alpha: Double = {
  7.                 alphaBeta = lambda.initAlpha(obs)
  8.                 normalize(0)
  9.                 sumUp
  10.         }

  11.   

  12.         def logProb: Double = foldLeft(lambda.getT, (s, t) => s + Math.log(ct(t)), Math.log(alpha))


  13.         private def sumUp: Double = {         
  14.                 foreach(1, lambda.getT, t => {
  15.                         updateAlpha(t)                // Implements first equation of M3
  16.                         normalize(t)                        // Normalized wit the sum of alpha(i), t 0 -> N-1
  17.                 })
  18.                 foldLeft(lambda.getN, (s, k) => s + alphaBeta(lambda.d_1, k))
  19.         }


  20.         private def updateAlpha(t: Int): Unit =
  21.                 foreach(lambda.getN, i => {
  22.                         val newValue = lambda.alpha(alphaBeta(t-1, i), i, obs(t))
  23.                         alphaBeta += (t, i, lambda.alpha(alphaBeta(t-1, i), i, obs(t)))
  24.                 })
  25. }


  26. object Alpha {

  27.         def apply(lambda: HMMLambda, obs: Array[Int]): Alpha = new Alpha(lambda,obs)
  28. }


  29. // --------------------------------  EOF -------------------------------------
复制代码

32
Lisrelchen(真实交易用户) 发表于 2016-4-20 09:48:45
  1. package org.scalaml.supervised.hmm

  2. import scala.util.{Try, Success, Failure}
  3. import org.apache.log4j.Logger
  4. import org.scalaml.util.DisplayUtils

  5. final protected class BaumWelchEM(
  6.                 config: HMMConfig,
  7.                 obs: Array[Int],
  8.                 numIters: Int,
  9.                 eps: Double)        extends HMMModel(HMMLambda(config), obs) {
  10.         import BaumWelchEM._
  11.        
  12.         check(config, obs, numIters, eps)
  13.         private val logger = Logger.getLogger("BaumWelchEM")

  14.         val state = HMMState(lambda, numIters)

  15.         val maxLikelihood: Option[Double] = {
  16.                 Try {
  17.                         var likelihood = frwrdBckwrdLattice
  18.                
  19.                                 // Apply the alpha / beta algorithm for each state
  20.                                 // until the likelihood value Alpha.alpha converges toward a maximum
  21.                         Range(0, state.maxIters) find( _ => {
  22.                                 lambda.estimate(state, obs)
  23.                                 val _likelihood = frwrdBckwrdLattice
  24.                                 val diff = likelihood - _likelihood
  25.                                 likelihood = _likelihood
  26.                             
  27.                                 diff < eps
  28.                         })
  29.                         match {
  30.                                 case Some(index) => likelihood
  31.                                 case None => throw new IllegalStateException("BaumWelchEM.maxLikelihood failed")
  32.                         }
  33.                 }
  34.                 match {
  35.                                 // If the maximum likelihood is computed, normalize the resulting Lambda model
  36.                         case Success(likelihood) => {
  37.                                 state.lambda.normalize
  38.                                 Some(likelihood)
  39.                         }
  40.                         case Failure(e) => DisplayUtils.none("BaumWelchEM.maxLikelihood", logger, e)
  41.                 }
  42.         }
  43.    
  44.                 /*

  45.         private def frwrdBckwrdLattice: Double  = {
  46.                         // Compute the forward pass given the sequence of observations obs
  47.                 val _alpha = Alpha(lambda, obs)
  48.                
  49.                         // Compute the probabilities of a state given the
  50.                 state.update(_alpha.getAlphaBeta, Beta(lambda, obs).getAlphaBeta, lambda.A, lambda.B, obs)
  51.                         // Finally returns the likelihood
  52.                 _alpha.alpha
  53.         }
  54. }



  55. object BaumWelchEM {
  56.         private val EPS = 1e-3   

  57.         def apply(config: HMMConfig, labels: Array[Int], numIters: Int, eps: Double): BaumWelchEM =
  58.                 new BaumWelchEM(config, labels, numIters,eps)
  59.        
  60.         private val EPS_LIMITS = (1e-8, 0.1)
  61.         private val MAX_NUM_ITERS = 1024
  62.        
  63.         private def check(config: HMMConfig, obs: Array[Int], numIters: Int, eps: Double): Unit = {
  64.                 require( !obs.isEmpty, "BaumWelchEM.check Observations are undefined")
  65.                 require(numIters > 1 && numIters < MAX_NUM_ITERS,
  66.                                 s"BaumWelchEM.check Maximum number of iterations $numIters is out of range")
  67.                 require(eps > EPS_LIMITS._1 && eps < EPS_LIMITS._2,
  68.                                 s"BaumWelchEM.check Convergence criteria for HMM Baum_Welch $eps is out of range")
  69.         }

  70. }
  71. // -----------------------------  EOF --------------------------------
复制代码

33
Lisrelchen(真实交易用户) 发表于 2016-4-20 09:52:01
  1. package org.scalaml.supervised.hmm

  2. import scala.util.{Try, Success, Failure}
  3. import org.apache.log4j.Logger

  4. import org.scalaml.core.Matrix
  5. import org.scalaml.util.DisplayUtils
  6. import HMMConfig._

  7. protected class Beta(lambda: HMMLambda, obs: Array[Int]) extends Pass(lambda, obs) {
  8.         private val logger = Logger.getLogger("Beta")


  9.         val complete = {
  10.                 Try {
  11.                                 // Creates the matrix of probabilities of a state given the
  12.                                 // observations, and initialize the probability for the last observation
  13.                                 // (index T-1) as 1.0
  14.                         alphaBeta = Matrix[Double](lambda.getT, lambda.getN)       
  15.                         alphaBeta += (lambda.d_1, 1.0)
  16.                                 // Normalize by computing (ct)
  17.                         normalize(lambda.d_1)
  18.                                 // Compute the beta probabilites for all the observations.
  19.                         sumUp
  20.                 }
  21.                 match {
  22.                         case Success(t) => true
  23.                         case Failure(e) => DisplayUtils.error("Beta.complete failed", logger, e); false
  24.                 }
  25.         }
  26.        

  27.         private def sumUp: Unit = {
  28.                         // Update and normalize the beta probabilities for all
  29.                         // the observations starting with index T-2.. befor normalization.
  30.                 (lambda.getT-2 to 0 by -1).foreach( t =>{
  31.                         updateBeta(t)
  32.                         normalize(t)
  33.                 })
  34.         }
  35.        

  36.         private def updateBeta(t: Int): Unit =
  37.                 foreach(lambda.getN, i =>
  38.                                 alphaBeta += (t, i, lambda.beta(alphaBeta(t+1, i), i, obs(t+1))))
  39. }



  40. object Beta {

  41.         def apply(lambda: HMMLambda,  obs: Array[Int]): Beta = new Beta(lambda, obs)
  42. }


  43. // --------------------------------  EOF -------------------------------------
复制代码

34
Lisrelchen(真实交易用户) 发表于 2016-4-20 09:57:49
  1. package org.scalaml.supervised.hmm

  2. import org.scalaml.core.Types.ScalaMl._
  3. import org.scalaml.core.Design.{PipeOperator, Model}
  4. import org.scalaml.core.XTSeries
  5. import org.scalaml.core.Matrix
  6. import scala.util.{Try, Success, Failure}
  7. import scala.annotation.implicitNotFound
  8. import org.apache.log4j.Logger
  9. import org.scalaml.util.DisplayUtils
  10. import HMM._

  11. object HMMForm extends Enumeration {
  12.         type HMMForm = Value
  13.         val EVALUATION, DECODING = Value
  14. }


  15. import HMMForm._

  16. abstract class HMMModel(val lambda: HMMLambda, val obs: Array[Int]) extends Model {
  17.         import HMMModel._
  18.        
  19.         check(obs)
  20. }


  21. object HMMModel {
  22.         private def check(obs: Array[Int]): Unit =
  23.                 require(!obs.isEmpty, "HMMModel.check Cannot create a model with undefined observations")
  24. }


  25. protected class Pass(lambda: HMMLambda, obs: Array[Int]) extends HMMModel(lambda, obs) {
  26.         protected var alphaBeta: Matrix[Double] = _
  27.         protected val ct = Array.fill(lambda.getT)(0.0)

  28.                 /**
  29.                  * Compute and apply the normalization factor ct for the computation of Alpha
  30.                  * [Formula M3] and Beta probabilities [Formula M7] for the observation at index t
  31.                  * @param t Index of the observation.
  32.                  */
  33.         protected def normalize(t: Int): Unit = {
  34.                 import HMMConfig._
  35.                 require(t >= 0 && t < lambda.getT, s"HMMModel.normalize Incorrect observation index t= $t")
  36.                
  37.                 ct.update(t, foldLeft(lambda.getN, (s, n) => s + alphaBeta(t, n)))
  38.                 alphaBeta /= (t, ct(t))
  39.         }

  40.         def getAlphaBeta: Matrix[Double] = alphaBeta
  41. }

  42. @implicitNotFound("HMM Conversion from DblVector to type T undefined")
  43. final protected class HMM[@specialized T <% Array[Int]](
  44.                 lambda: HMMLambda,
  45.                 form: HMMForm,
  46.                 maxIters: Int)
  47.                 (implicit f: DblVector => T)        extends PipeOperator[DblVector, HMMPredictor] {
  48.        
  49.         check(maxIters)
  50.        
  51.         private val logger = Logger.getLogger("HMM")
  52.         private[this] val state = HMMState(lambda, maxIters)
  53.        

  54.         override def |> : PartialFunction[DblVector, HMMPredictor] = {
  55.                 case obs: DblVector if( !obs.isEmpty) => {
  56.                         Try {
  57.                                 form match {
  58.                                         case EVALUATION => evaluate(obs)
  59.                                         case DECODING => decode(obs)
  60.                                 }
  61.                         } match {
  62.                                 case Success(prediction) => prediction
  63.                                 case Failure(e) =>
  64.                                         DisplayUtils.error("HMM.|> ", logger, e)
  65.                                         nullHMMPredictor
  66.                         }
  67.                 }
  68.         }
  69.        

  70.         def decode(obs: T): HMMPredictor = (ViterbiPath(lambda, obs).maxDelta, state.QStar())
  71.        

  72.         def evaluate(obs: T): HMMPredictor = (-Alpha(lambda, obs).logProb, obs)
  73.        

  74.         final def getModel: HMMLambda = lambda
  75. }


  76. object HMM {

  77.         type HMMPredictor = (Double, Array[Int])
  78.         val nullHMMPredictor = (-1.0, Array.empty[Int])
  79.        
  80.         def apply[T <% Array[Int]](
  81.                         lambda: HMMLambda,
  82.                         form: HMMForm,
  83.                         maxIters: Int)
  84.                         (implicit f: DblVector => T): HMM[T] =        new HMM[T](lambda, form, maxIters)

  85.         def apply[T <% Array[Int]](
  86.                         lambda: HMMLambda,
  87.                         form: HMMForm)
  88.                         (implicit f: DblVector => T): HMM[T] =  new HMM[T](lambda, form, HMMState.DEFAULT_MAXITERS)
  89.        
  90.        
  91.         def apply[T <% Array[Int]](
  92.                         config: HMMConfig,
  93.                         obsIndx: Array[Int],
  94.                         form: HMMForm,  
  95.                         maxIters: Int,
  96.                         eps: Double)
  97.                         (implicit f: DblVector => T): Option[HMM[T]] = {
  98.           
  99.                 val baumWelchEM = new BaumWelchEM(config, obsIndx, maxIters, eps)
  100.                 baumWelchEM.maxLikelihood.map(_ => new HMM[T](baumWelchEM.lambda, form, maxIters))
  101.         }
  102.        
  103.         val MAX_NUM_ITERS = 1024
  104.         private def check(maxIters: Int): Unit = {
  105.                 require(maxIters > 1 && maxIters < MAX_NUM_ITERS,
  106.                     s"HMM.check  Maximum number of iterations to train a HMM $maxIters is out of bounds")
  107.         }
  108. }

  109. // ----------------------------------------  EOF ------------------------------------------------------------
复制代码

35
Nicolle(真实交易用户) 学生认证  发表于 2016-4-20 10:13:28
提示: 作者被禁止或删除 内容自动屏蔽

36
Lisrelchen(真实交易用户) 发表于 2016-4-22 08:09:34

37
龙中龙(未真实交易用户) 发表于 2016-5-3 06:55:30
看一看,谢谢楼主

38
sacromento(未真实交易用户) 学生认证  发表于 2016-5-31 02:34:11 来自手机
学习了,谢谢

39
redalert99(真实交易用户) 发表于 2016-7-16 15:24:10
刚刚学习scala
  1. object Hello {
  2.   def main (args:Array[String]):Unit={
  3.     println("Hello Scala!")
  4.   }
  5. }
复制代码

40
kpochee(未真实交易用户) 发表于 2016-8-2 15:06:40
谢谢分享

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
jg-xs1
拉您进交流群
GMT+8, 2026-1-1 08:37