db.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package db
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "strings"
  7. fp "github.com/cloudflare/mitmengine/fputil"
  8. )
  9. // A Database contains a collection of records containing software signatures.
  10. type Database struct {
  11. Records []Record
  12. }
  13. // NewDatabase returns a new Database initialized from the configuration.
  14. func NewDatabase(input io.Reader) (Database, error) {
  15. var a Database
  16. // get exact length
  17. a.Records = []Record{}
  18. err := a.Load(input)
  19. return a, err
  20. }
  21. // Load records from input into the database, and return an error on bad records.
  22. func (a *Database) Load(input io.Reader) error {
  23. var record Record
  24. scanner := bufio.NewScanner(input)
  25. for scanner.Scan() {
  26. recordString := scanner.Text()
  27. if idx := strings.IndexRune(recordString, '\t'); idx != -1 {
  28. // remove anything before a tab
  29. recordString = recordString[idx+1:]
  30. }
  31. if idx := strings.IndexRune(recordString, '#'); idx != -1 {
  32. // remove comments at end of lines
  33. recordString = recordString[:idx]
  34. }
  35. // remove any whitespace or quotes
  36. recordString = strings.Trim(strings.TrimSpace(recordString), "\"")
  37. if len(recordString) == 0 {
  38. continue // skip empty lines
  39. }
  40. if err := record.Parse(recordString); err != nil {
  41. return fmt.Errorf("unable to parse record: %s, %s", recordString, err)
  42. }
  43. a.Add(record)
  44. }
  45. return nil
  46. }
  47. // Len returns the length of the database
  48. func (a *Database) Len() int {
  49. return len(a.Records)
  50. }
  51. // Add a single record to the database.
  52. func (a *Database) Add(record Record) int {
  53. a.Records = append(a.Records, record)
  54. return len(a.Records)
  55. }
  56. // Clear all records from the database.
  57. func (a *Database) Clear() {
  58. a.Records = []Record{}
  59. }
  60. // Dump records in the database to output.
  61. func (a Database) Dump(output io.Writer) error {
  62. for _, record := range a.Records {
  63. _, err := fmt.Fprintln(output, record)
  64. if err != nil {
  65. return err
  66. }
  67. }
  68. return nil
  69. }
  70. // GetByRequestFingerprint returns all records in the database matching the
  71. // request fingerprint.
  72. func (a Database) GetByRequestFingerprint(requestFingerprint fp.RequestFingerprint) []int {
  73. return a.GetBy(func(r Record) bool {
  74. match, _ := r.RequestSignature.Match(requestFingerprint)
  75. return match != fp.MatchImpossible
  76. })
  77. }
  78. // GetByUAFingerprint returns all records in the database matching the
  79. // user agent fingerprint.
  80. func (a Database) GetByUAFingerprint(uaFingerprint fp.UAFingerprint) []int {
  81. return a.GetBy(func(r Record) bool { return r.UASignature.Match(uaFingerprint) != fp.MatchImpossible })
  82. }
  83. // GetBy returns a list of records for which GetBy returns true.
  84. func (a Database) GetBy(getFunc func(Record) bool) []int {
  85. var recordIds []int
  86. for id, record := range a.Records {
  87. if getFunc(record) {
  88. recordIds = append(recordIds, id)
  89. }
  90. }
  91. return recordIds
  92. }
  93. // DeleteBy deletes records for which rmFunc returns true.
  94. func (a *Database) DeleteBy(deleteFunc func(Record) bool) {
  95. recordIds := a.GetBy(deleteFunc)
  96. for _, id := range recordIds {
  97. a.Records = append(a.Records[:id], a.Records[id+1:]...)
  98. }
  99. }
  100. // MergeBy merges records for which mergeFunc returns true.
  101. func (a *Database) MergeBy(mergeFunc func(Record, Record) bool) (int, int) {
  102. before := len(a.Records)
  103. for id1 := 0; id1 < len(a.Records); id1++ {
  104. for id2 := 0; id2 < len(a.Records); id2++ {
  105. if id1 == id2 {
  106. continue
  107. }
  108. record1 := a.Records[id1]
  109. record2 := a.Records[id2]
  110. if mergeFunc(record1, record2) {
  111. a.Records[id1] = record1.Merge(record2)
  112. // If elements are deleted from the map during the iteration, they will not be produced.
  113. // https://golang.org/ref/spec#For_statements
  114. a.Records = append(a.Records[:id2], a.Records[id2+1:]...)
  115. }
  116. }
  117. }
  118. return before, len(a.Records)
  119. }