opencv_parser.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. // Copyright 2011 The Graphics-Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package detect
  5. import (
  6. "bytes"
  7. "encoding/xml"
  8. "errors"
  9. "fmt"
  10. "image"
  11. "io"
  12. "io/ioutil"
  13. "strconv"
  14. "strings"
  15. )
  // xmlFeature is the decoded form of a single tree entry of a stage.
  //
  // All element paths go through "grp", the element name that ParseOpenCV
  // substitutes for OpenCV's anonymous "<_>" element before decoding.
  type xmlFeature struct {
  	// Rects holds the raw text of every rectangle of the feature; each
  	// entry is parsed later by buildFeature.
  	Rects []string `xml:"grp>feature>rects>grp"`
  	// Tilted is the feature's <tilted> value; buildCascade rejects any
  	// tree where it is non-zero.
  	Tilted int `xml:"grp>feature>tilted"`
  	// Threshold is the tree's <threshold> value.
  	Threshold float64 `xml:"grp>threshold"`
  	// Left and Right are the tree's <left_val> and <right_val> values.
  	Left  float64 `xml:"grp>left_val"`
  	Right float64 `xml:"grp>right_val"`
  }
  23. type xmlStages struct {
  24. Trees []xmlFeature `xml:"trees>grp"`
  25. Stage_threshold float64 `xml:"stage_threshold"`
  26. Parent int `xml:"parent"`
  27. Next int `xml:"next"`
  28. }
  29. type opencv_storage struct {
  30. Any struct {
  31. XMLName xml.Name
  32. Type string `xml:"type_id,attr"`
  33. Size string `xml:"size"`
  34. Stages []xmlStages `xml:"stages>grp"`
  35. } `xml:",any"`
  36. }
  37. func buildFeature(r string) (f Feature, err error) {
  38. var x, y, w, h int
  39. var weight float64
  40. _, err = fmt.Sscanf(r, "%d %d %d %d %f", &x, &y, &w, &h, &weight)
  41. if err != nil {
  42. return
  43. }
  44. f.Rect = image.Rect(x, y, x+w, y+h)
  45. f.Weight = weight
  46. return
  47. }
  48. func buildCascade(s *opencv_storage) (c *Cascade, name string, err error) {
  49. if s.Any.Type != "opencv-haar-classifier" {
  50. err = fmt.Errorf("got %s want opencv-haar-classifier", s.Any.Type)
  51. return
  52. }
  53. name = s.Any.XMLName.Local
  54. c = &Cascade{}
  55. sizes := strings.Split(s.Any.Size, " ")
  56. w, err := strconv.Atoi(sizes[0])
  57. if err != nil {
  58. return nil, "", err
  59. }
  60. h, err := strconv.Atoi(sizes[1])
  61. if err != nil {
  62. return nil, "", err
  63. }
  64. c.Size = image.Pt(w, h)
  65. c.Stage = []CascadeStage{}
  66. for _, stage := range s.Any.Stages {
  67. cs := CascadeStage{
  68. Classifier: []Classifier{},
  69. Threshold: stage.Stage_threshold,
  70. }
  71. for _, tree := range stage.Trees {
  72. if tree.Tilted != 0 {
  73. err = errors.New("Cascade does not support tilted features")
  74. return
  75. }
  76. cls := Classifier{
  77. Feature: []Feature{},
  78. Threshold: tree.Threshold,
  79. Left: tree.Left,
  80. Right: tree.Right,
  81. }
  82. for _, rect := range tree.Rects {
  83. f, err := buildFeature(rect)
  84. if err != nil {
  85. return nil, "", err
  86. }
  87. cls.Feature = append(cls.Feature, f)
  88. }
  89. cs.Classifier = append(cs.Classifier, cls)
  90. }
  91. c.Stage = append(c.Stage, cs)
  92. }
  93. return
  94. }
  95. // ParseOpenCV produces a detection Cascade from an OpenCV XML file.
  96. func ParseOpenCV(r io.Reader) (cascade *Cascade, name string, err error) {
  97. // BUG(crawshaw): tag-based parsing doesn't seem to work with <_>
  98. buf, err := ioutil.ReadAll(r)
  99. if err != nil {
  100. return
  101. }
  102. buf = bytes.Replace(buf, []byte("<_>"), []byte("<grp>"), -1)
  103. buf = bytes.Replace(buf, []byte("</_>"), []byte("</grp>"), -1)
  104. s := &opencv_storage{}
  105. err = xml.Unmarshal(buf, s)
  106. if err != nil {
  107. return
  108. }
  109. return buildCascade(s)
  110. }