- 阅读权限
- 255
- 威望
- 1 级
- 论坛币
- 49655 个
- 通用积分
- 55.9937
- 学术水平
- 370 点
- 热心指数
- 273 点
- 信用等级
- 335 点
- 经验
- 57805 点
- 帖子
- 4005
- 精华
- 21
- 在线时间
- 582 小时
- 注册时间
- 2005-5-8
- 最后登录
- 2023-11-26
|
Decision Tree in Java
- /**
- *
- */
- package dt;
- import java.util.*;
/**
 * Holds the labelled training examples and answers counting queries over
 * them.
 *
 * Each example is a set of attribute name/value pairs plus a boolean
 * classification. Examples are not required to define the same attributes;
 * an example that lacks an attribute simply never matches a query on that
 * attribute.
 */
class Examples {
  /**
   * A single training example: attribute name/value pairs together with the
   * boolean class the example belongs to.
   */
  class Example {
    private final Map<String, String> values;
    private final boolean classifier;

    /**
     * Builds an example from parallel arrays of names and values.
     *
     * @param attributeNames  attribute names; position i names attributeValues[i]
     * @param attributeValues attribute values, same length as attributeNames
     * @param classifier      the classification of this example
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public Example(String[] attributeNames, String[] attributeValues,
                   boolean classifier) {
      // Checked explicitly rather than with 'assert': assertions are disabled
      // by default, and a mismatch would otherwise drop values silently or
      // fail later with an ArrayIndexOutOfBoundsException.
      if (attributeNames.length != attributeValues.length) {
        throw new IllegalArgumentException(
            "Expected " + attributeNames.length + " attribute values but got "
                + attributeValues.length);
      }

      values = new HashMap<String, String>();
      for (int i = 0; i < attributeNames.length; i++) {
        values.put(attributeNames[i], attributeValues[i]);
      }

      this.classifier = classifier;
    }

    /**
     * Builds an example from a map of attribute names to values.
     *
     * The map is copied, so later changes made by the caller do not alter
     * this example.
     *
     * @param attributes attribute name to value mapping
     * @param classifier the classification of this example
     */
    public Example(Map<String, String> attributes, boolean classifier) {
      this.classifier = classifier;
      this.values = new HashMap<String, String>(attributes);
    }

    /**
     * Returns a read-only view of the attribute names this example defines.
     */
    public Set<String> getAttributes() {
      return Collections.unmodifiableSet(values.keySet());
    }

    /**
     * Returns the value of the given attribute, or null when this example
     * does not define it.
     */
    public String getAttributeValue(String attribute) {
      return values.get(attribute);
    }

    /**
     * Returns true when this example has the given classification.
     */
    public boolean matchesClass(boolean classifier) {
      return classifier == this.classifier;
    }

    /**
     * Returns true when this example defines the attribute with exactly the
     * given value. A null value, or an attribute missing from this example,
     * never matches.
     */
    private boolean hasValue(String attribute, String value) {
      return value != null && value.equals(values.get(attribute));
    }
  }

  private final List<Example> examples;

  public Examples() {
    examples = new ArrayList<Example>();
  }

  /**
   * Adds an example given as parallel arrays of attribute names and values.
   *
   * @throws IllegalArgumentException if the two arrays differ in length
   */
  public void add(String[] attributeNames, String[] attributeValues,
                  boolean classifier) {
    examples.add(new Example(attributeNames, attributeValues, classifier));
  }

  /**
   * Adds an example given as a map of attribute names to values. The map is
   * copied.
   */
  public void add(Map<String, String> attributes, boolean classifier) {
    examples.add(new Example(attributes, classifier));
  }

  /**
   * Returns the number of examples where the attribute has the specified
   * 'decision' value. Examples that do not define the attribute are not
   * counted.
   */
  int countDecisions(String attribute, String decision) {
    int count = 0;
    for (Example e : examples) {
      if (e.hasValue(attribute, decision)) {
        count++;
      }
    }
    return count;
  }

  /**
   * Returns a map from each attribute name to a set of all values used in the
   * examples for that attribute.
   */
  public Map<String, Set<String> > extractDecisions() {
    Map<String, Set<String> > decisions = new HashMap<String, Set<String> >();
    for (String attribute : extractAttributes()) {
      decisions.put(attribute, extractDecisions(attribute));
    }
    return decisions;
  }

  /**
   * Counts the negative examples matching 'attributes' plus the additional
   * constraint attribute == decision. The supplied map is not modified.
   */
  public int countNegative(String attribute, String decision,
                           Map<String, String> attributes) {
    return countClassifier(false, attribute, decision, attributes);
  }

  /**
   * Counts the positive examples matching 'attributes' plus the additional
   * constraint attribute == decision. The supplied map is not modified.
   */
  public int countPositive(String attribute, String decision,
                           Map<String, String> attributes) {
    return countClassifier(true, attribute, decision, attributes);
  }

  /**
   * Counts the negative examples matching every name/value pair in
   * 'attributes'.
   */
  public int countNegative(Map<String, String> attributes) {
    return countClassifier(false, attributes);
  }

  /**
   * Counts the positive examples matching every name/value pair in
   * 'attributes'.
   */
  public int countPositive(Map<String, String> attributes) {
    return countClassifier(true, attributes);
  }

  /**
   * Counts the examples matching 'attributes' plus the additional constraint
   * attribute == decision. The supplied map is not modified.
   */
  public int count(String attribute, String decision, Map<String, String> attributes) {
    // Work on a copy so the caller's map is left untouched.
    Map<String, String> constraints = new HashMap<String, String>(attributes);
    constraints.put(attribute, decision);
    return count(constraints);
  }

  /**
   * Counts the examples matching every name/value pair in 'attributes'. An
   * empty map matches every example.
   */
  public int count(Map<String, String> attributes) {
    int count = 0;
    for (Example e : examples) {
      if (matches(e, attributes)) {
        count++;
      }
    }
    return count;
  }

  /**
   * Counts the examples that match every name/value pair in 'attributes' and
   * also have the given classification.
   */
  public int countClassifier(boolean classifier, Map<String, String> attributes) {
    int count = 0;
    for (Example e : examples) {
      if (matches(e, attributes) && e.matchesClass(classifier)) {
        count++;
      }
    }
    return count;
  }

  /**
   * Counts the examples with the given classification that match
   * 'attributes' plus the additional constraint attribute == decision. The
   * supplied map is not modified.
   */
  public int countClassifier(boolean classifier, String attribute,
                             String decision, Map<String, String> attributes) {
    // Work on a copy so the caller's map is left untouched.
    Map<String, String> constraints = new HashMap<String, String>(attributes);
    constraints.put(attribute, decision);
    return countClassifier(classifier, constraints);
  }

  /**
   * Returns the number of examples.
   */
  public int count() {
    return examples.size();
  }

  /**
   * Returns a set of attribute names used in the examples.
   */
  public Set<String> extractAttributes() {
    Set<String> attributes = new HashSet<String>();
    for (Example e : examples) {
      attributes.addAll(e.getAttributes());
    }
    return attributes;
  }

  /**
   * Returns the distinct values the examples use for one attribute. Examples
   * that do not define the attribute contribute nothing, so the result never
   * contains null.
   */
  private Set<String> extractDecisions(String attribute) {
    Set<String> decisions = new HashSet<String>();
    for (Example e : examples) {
      String value = e.getAttributeValue(attribute);
      if (value != null) {
        decisions.add(value);
      }
    }
    return decisions;
  }

  /**
   * Returns true when the example satisfies every name/value constraint.
   */
  private boolean matches(Example example, Map<String, String> constraints) {
    for (Map.Entry<String, String> wanted : constraints.entrySet()) {
      if (!example.hasValue(wanted.getKey(), wanted.getValue())) {
        return false;
      }
    }
    return true;
  }
}
- /**
- *
- */
- package dt;
- import java.util.*;
- public class DecisionTree {
- /**
- * Contains the set of available attributes.
- */
- private LinkedHashSet<String> attributes;
- /**
- * Maps a attribute name to a set of possible decisions for that attribute.
- */
- private Map<String, Set<String> > decisions;
- private boolean decisionsSpecified;
- /**
- * Contains the examples to be processed into a decision tree.
- *
- * The 'attributes' and 'decisions' member variables should be updated
- * prior to adding examples that refer to new attributes or decisions.
- */
- private Examples examples;
- /**
- * Indicates if the provided data has been processed into a decision tree.
- *
- * This value is initially false, and is reset any time additional data is
- * provided.
- */
- private boolean compiled;
- /**
- * Contains the top-most attribute of the decision tree.
- *
- * For a tree where the decision requires no attributes,
- * the rootAttribute yields a boolean classification.
- *
- */
- private Attribute rootAttribute;
- private Algorithm algorithm;
- public DecisionTree() {
- algorithm = null;
- examples = new Examples();
- attributes = new LinkedHashSet<String>();
- decisions = new HashMap<String, Set<String> >();
- decisionsSpecified = false;
- }
- private void setDefaultAlgorithm() {
- if ( algorithm == null )
- setAlgorithm(new ID3Algorithm(examples));
- }
- public void setAlgorithm(Algorithm algorithm) {
- this.algorithm = algorithm;
- }
- /**
- * Saves the array of attribute names in an insertion ordered set.
- *
- * The ordering of attribute names is used when addExamples is called to
- * determine which values correspond with which names.
- *
- */
- public DecisionTree setAttributes(String[] attributeNames) {
- compiled = false;
- decisions.clear();
- decisionsSpecified = false;
- attributes.clear();
- for ( int i = 0 ; i < attributeNames.length ; i++ )
- attributes.add(attributeNames[i]);
- return this;
- }
- /**
- */
- public DecisionTree setDecisions(String attributeName, String[] decisions) {
- if ( !attributes.contains(attributeName) ) {
- // TODO some kind of warning or something
- return this;
- }
- compiled = false;
- decisionsSpecified = true;
- Set<String> decisionsSet = new HashSet<String>();
- for ( int i = 0 ; i < decisions.length ; i++ )
- decisionsSet.add(decisions[i]);
- this.decisions.put(attributeName, decisionsSet);
- return this;
- }
- /**
- */
- public DecisionTree addExample(String[] attributeValues, boolean classification) throws UnknownDecisionException {
- String[] attributes = this.attributes.toArray(new String[0]);
- if ( decisionsSpecified )
- for ( int i = 0 ; i < attributeValues.length ; i++ )
- if ( !decisions.get(attributes[i]).contains(attributeValues[i]) ) {
- throw new UnknownDecisionException(attributes[i], attributeValues[i]);
- }
- compiled = false;
- examples.add(attributes, attributeValues, classification);
-
- return this;
- }
- public DecisionTree addExample(Map<String, String> attributes, boolean classification) throws UnknownDecisionException {
- compiled = false;
- examples.add(attributes, classification);
- return this;
- }
- public boolean apply(Map<String, String> data) throws BadDecisionException {
- compile();
- return rootAttribute.apply(data);
- }
- private Attribute compileWalk(Attribute current, Map<String, String> chosenAttributes, Set<String> usedAttributes) {
- // if the current attribute is a leaf, then there are no decisions and thus no
- // further attributes to find.
- if ( current.isLeaf() )
- return current;
- // get decisions for the current attribute (from this.decisions)
- String attributeName = current.getName();
- // remove this attribute from all further consideration
- usedAttributes.add(attributeName);
- for ( String decisionName : decisions.get(attributeName) ) {
- // overwrite the attribute decision for each value considered
- chosenAttributes.put(attributeName, decisionName);
- // find the next attribute to choose for the considered decision
- // build the subtree from this new attribute, pre-order
- // insert the newly-built subtree into the open decision slot
- current.addDecision(decisionName, compileWalk(algorithm.nextAttribute(chosenAttributes, usedAttributes), chosenAttributes, usedAttributes));
- }
- // remove the attribute decision before we walk back up the tree.
- chosenAttributes.remove(attributeName);
- // return the subtree so that it can be inserted into the parent tree.
- return current;
- }
- public void compile() {
- // skip compilation if already done.
- if ( compiled )
- return;
- // if no algorithm is set beforehand, select the default one.
- setDefaultAlgorithm();
- Map<String, String> chosenAttributes = new HashMap<String, String>();
- Set<String> usedAttributes = new HashSet<String>();
- if ( !decisionsSpecified )
- decisions = examples.extractDecisions();
- // find the root attribute (either leaf or non)
- // walk the tree, adding attributes as needed under each decision
- // save the original attribute as the root attribute.
- rootAttribute = compileWalk(algorithm.nextAttribute(chosenAttributes, usedAttributes), chosenAttributes, usedAttributes);
- compiled = true;
- }
- public String toString() {
- compile();
- if ( rootAttribute != null )
- return rootAttribute.toString();
- else
- return "";
- }
- public Attribute getRoot() {
- return rootAttribute;
- }
- }
- 复制代码
- https://github.com/saebyn/java-decision-tree/
复制代码
|
|