subset.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #!/usr/bin/env python
  2. from sys import argv, exit, stdout, stderr
  3. from random import randint
  4. method = 0
  5. global n
  6. global dataset_filename
  7. subset_filename = ""
  8. rest_filename = ""
  9. def exit_with_help():
  10. print """\
  11. Usage: %s [options] dataset number [output1] [output2]
  12. This script selects a subset of the given dataset.
  13. options:
  14. -s method : method of selection (default 0)
  15. 0 -- stratified selection (classification only)
  16. 1 -- random selection
  17. output1 : the subset (optional)
  18. output2 : rest of the data (optional)
  19. If output1 is omitted, the subset will be printed on the screen.""" % argv[0]
  20. exit(1)
  21. def process_options():
  22. global method, n
  23. global dataset_filename, subset_filename, rest_filename
  24. argc = len(argv)
  25. if argc < 3:
  26. exit_with_help()
  27. i = 1
  28. while i < len(argv):
  29. if argv[i][0] != "-":
  30. break
  31. if argv[i] == "-s":
  32. i = i + 1
  33. method = int(argv[i])
  34. if method < 0 or method > 1:
  35. print "Unknown selection method %d" % (method)
  36. exit_with_help()
  37. i = i + 1
  38. dataset_filename = argv[i]
  39. n = int(argv[i+1])
  40. if i+2 < argc:
  41. subset_filename = argv[i+2]
  42. if i+3 < argc:
  43. rest_filename = argv[i+3]
  44. def main():
  45. class Label:
  46. def __init__(self, label, index, selected):
  47. self.label = label
  48. self.index = index
  49. self.selected = selected
  50. def __cmp__(self, other):
  51. return cmp(self.label, other.label)
  52. process_options()
  53. # get labels
  54. i = 0
  55. labels = []
  56. f = open(dataset_filename, 'r')
  57. for line in f:
  58. labels.append(Label(float((line.split())[0]), i, 0))
  59. i = i + 1
  60. f.close()
  61. l = i
  62. # determine where to output
  63. if subset_filename != "":
  64. file1 = open(subset_filename, 'w')
  65. else:
  66. file1 = stdout
  67. split = 0
  68. if rest_filename != "":
  69. split = 1
  70. file2 = open(rest_filename, 'w')
  71. # select the subset
  72. warning = 0
  73. if method == 0: # stratified
  74. labels.sort()
  75. label_end = labels[l-1].label + 1
  76. labels.append(Label(label_end, l, 0))
  77. begin = 0
  78. label = labels[begin].label
  79. for i in range(l+1):
  80. new_label = labels[i].label
  81. if new_label != label:
  82. nr_class = i - begin
  83. k = i*n/l - begin*n/l
  84. # at least one instance per class
  85. if k == 0:
  86. k = 1
  87. warning = warning + 1
  88. for j in range(nr_class):
  89. if randint(0, nr_class-j-1) < k:
  90. labels[begin+j].selected = 1
  91. k = k - 1
  92. begin = i
  93. label = new_label
  94. elif method == 1: # random
  95. k = n
  96. for i in range(l):
  97. if randint(0,l-i-1) < k:
  98. labels[i].selected = 1
  99. k = k - 1
  100. i = i + 1
  101. # output
  102. i = 0
  103. if method == 0:
  104. labels.sort(lambda x, y: cmp(int(x.index), int(y.index)))
  105. f = open(dataset_filename, 'r')
  106. for line in f:
  107. if labels[i].selected == 1:
  108. file1.write(line)
  109. else:
  110. if split == 1:
  111. file2.write(line)
  112. i = i + 1
  113. if warning > 0:
  114. stderr.write("""\
  115. Warning:
  116. 1. You may have regression data. Please use -s 1.
  117. 2. Classification data unbalanced or too small. We select at least 1 per class.
  118. The subset thus contains %d instances.
  119. """ % (n+warning))
  120. # cleanup
  121. f.close()
  122. file1.close()
  123. if split == 1:
  124. file2.close()
  125. main()