score.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. // License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
  2. package subseq
  3. import (
  4. "fmt"
  5. "slices"
  6. "strings"
  7. "kitty/tools/utils"
  8. "kitty/tools/utils/images"
  9. )
  10. var _ = fmt.Print
  // LEVEL1 is the default set of level 1 characters, used by ScoreItems when
  // Options.Level1 is empty. A needle character that directly follows one of
  // these in the haystack gets the level factor 90 (see level_factor_for).
  const LEVEL1 = "/"
  // LEVEL2 is the default set of level 2 characters, used by ScoreItems when
  // Options.Level2 is empty. The associated level factor is 80.
  const LEVEL2 = "-_0123456789"
  // LEVEL3 is the default set of level 3 characters, used by ScoreItems when
  // Options.Level3 is empty. The associated level factor is 70.
  const LEVEL3 = "."
  // resolved_options_type carries the three level character sets as rune
  // slices, the form in which level_factor_for searches them.
  type resolved_options_type struct {
  	level1 []rune // characters yielding the level factor 90
  	level2 []rune // characters yielding the level factor 80
  	level3 []rune // characters yielding the level factor 70
  }
  // Options configures ScoreItems.
  type Options struct {
  	// Level1 is the set of level 1 characters. An empty string makes
  	// ScoreItems fall back to LEVEL1.
  	Level1 string
  	// Level2 is the set of level 2 characters. An empty string makes
  	// ScoreItems fall back to LEVEL2.
  	Level2 string
  	// Level3 is the set of level 3 characters. An empty string makes
  	// ScoreItems fall back to LEVEL3.
  	Level3 string
  	// NumberOfThreads is handed unchanged to the parallel execution context
  	// used by ScoreItems.
  	NumberOfThreads int
  }
  // Match is the outcome of scoring a single item against the query.
  type Match struct {
  	// Positions has one entry per query character. When Score is greater
  	// than zero the entries are the best-scoring match locations after
  	// conversion by utils.RuneOffsetsToByteOffsets (presumably byte offsets
  	// into Text; confirm against that helper). Otherwise they are all zero.
  	Positions []int
  	// Score is the best score found; zero means the item did not match.
  	Score float64
  	// idx is the index of the item in the slice given to ScoreItems.
  	idx int
  	// Text is the item that was scored, unmodified.
  	Text string
  }
  29. func level_factor_for(current_lcase, last_lcase, current_cased, last_cased rune, opts *resolved_options_type) int {
  30. switch {
  31. case slices.Contains(opts.level1, last_lcase):
  32. return 90
  33. case slices.Contains(opts.level2, last_lcase):
  34. return 80
  35. case last_lcase == last_cased && current_lcase != current_cased: // camelCase
  36. return 80
  37. case slices.Contains(opts.level3, last_lcase):
  38. return 70
  39. default:
  40. return 0
  41. }
  42. }
  // workspace_type is scratch storage reused by score_item from one item to
  // the next so that buffers are not reallocated for every item.
  type workspace_type struct {
  	// positions[j] lists every haystack index at which needle char j occurs.
  	positions [][]int
  	// level_factors[i] is the level factor recorded for haystack index i.
  	level_factors []int
  	// address[j] selects which entry of positions[j] is currently in use.
  	address []int
  	// max_score_per_char is the largest contribution a single consecutive
  	// needle char adds to the score.
  	max_score_per_char float64
  }
  49. func (w *workspace_type) initialize(haystack_sz, needle_sz int) {
  50. if cap(w.positions) < needle_sz {
  51. w.positions = make([][]int, needle_sz)
  52. } else {
  53. w.positions = w.positions[:needle_sz]
  54. }
  55. if cap(w.level_factors) < haystack_sz {
  56. w.level_factors = make([]int, 2*haystack_sz)
  57. } else {
  58. w.level_factors = w.level_factors[:haystack_sz]
  59. }
  60. for i, s := range w.positions {
  61. if cap(s) < haystack_sz {
  62. w.positions[i] = make([]int, 0, 2*haystack_sz)
  63. } else {
  64. w.positions[i] = w.positions[i][:0]
  65. }
  66. }
  67. if cap(w.address) < needle_sz {
  68. w.address = make([]int, needle_sz)
  69. }
  70. w.address = utils.Memset(w.address)
  71. }
  72. func (w *workspace_type) position(x int) int { // the position of xth needle char in the haystack for the current address
  73. return w.positions[x][w.address[x]]
  74. }
  75. func (w *workspace_type) increment_address() bool {
  76. pos := len(w.positions) - 1 // the last needle char
  77. for {
  78. w.address[pos]++
  79. if w.address[pos] < len(w.positions[pos]) {
  80. return true
  81. }
  82. if pos == 0 {
  83. break
  84. }
  85. w.address[pos] = 0
  86. pos--
  87. }
  88. return false
  89. }
  90. func (w *workspace_type) address_is_monotonic() bool {
  91. // Check if the character positions pointed to by the current address are monotonic
  92. for i := 1; i < len(w.positions); i++ {
  93. if w.position(i) <= w.position(i-1) {
  94. return false
  95. }
  96. }
  97. return true
  98. }
  99. func (w *workspace_type) calc_score() (ans float64) {
  100. distance, pos := 0, 0
  101. for i := 0; i < len(w.positions); i++ {
  102. pos = w.position(i)
  103. if i == 0 {
  104. distance = pos + 1
  105. } else {
  106. distance = pos - w.position(i-1)
  107. if distance < 2 {
  108. ans += w.max_score_per_char // consecutive chars
  109. continue
  110. }
  111. }
  112. if w.level_factors[pos] > 0 {
  113. ans += (100.0 * w.max_score_per_char) / float64(w.level_factors[pos]) // at a special location
  114. } else {
  115. ans += (0.75 * w.max_score_per_char) / float64(distance)
  116. }
  117. }
  118. return
  119. }
  120. func has_atleast_one_match(w *workspace_type) (found bool) {
  121. p := -1
  122. for i := 0; i < len(w.positions); i++ {
  123. if len(w.positions[i]) == 0 { // all chars of needle not in haystack
  124. return false
  125. }
  126. found = false
  127. for _, pos := range w.positions[i] {
  128. if pos > p {
  129. p = pos
  130. found = true
  131. break
  132. }
  133. }
  134. if !found { // chars of needle not present in sequence in haystack
  135. return false
  136. }
  137. }
  138. return true
  139. }
  140. func score_item(item string, idx int, needle []rune, opts *resolved_options_type, w *workspace_type) *Match {
  141. ans := &Match{idx: idx, Text: item, Positions: make([]int, len(needle))}
  142. haystack := []rune(strings.ToLower(item))
  143. orig_haystack := []rune(item)
  144. w.initialize(len(orig_haystack), len(needle))
  145. for i := 0; i < len(haystack); i++ {
  146. level_factor_calculated := false
  147. for j := 0; j < len(needle); j++ {
  148. if needle[j] == haystack[i] {
  149. if !level_factor_calculated {
  150. level_factor_calculated = true
  151. if i > 0 {
  152. w.level_factors[i] = level_factor_for(haystack[i], haystack[i-1], orig_haystack[i], orig_haystack[i-1], opts)
  153. }
  154. }
  155. w.positions[j] = append(w.positions[j], i)
  156. }
  157. }
  158. }
  159. w.max_score_per_char = (1.0/float64(len(orig_haystack)) + 1.0/float64(len(needle))) / 2.0
  160. if !has_atleast_one_match(w) {
  161. return ans
  162. }
  163. var score float64
  164. for {
  165. if w.address_is_monotonic() {
  166. score = w.calc_score()
  167. if score > ans.Score {
  168. ans.Score = score
  169. for i := range ans.Positions {
  170. ans.Positions[i] = w.position(i)
  171. }
  172. }
  173. }
  174. if !w.increment_address() {
  175. break
  176. }
  177. }
  178. if ans.Score > 0 {
  179. adjust := utils.RuneOffsetsToByteOffsets(item)
  180. for i := range ans.Positions {
  181. ans.Positions[i] = adjust(ans.Positions[i])
  182. }
  183. }
  184. return ans
  185. }
  186. func ScoreItems(query string, items []string, opts Options) []*Match {
  187. ctx := images.Context{}
  188. ctx.SetNumberOfThreads(opts.NumberOfThreads)
  189. ans := make([]*Match, len(items))
  190. results := make(chan *Match, len(items))
  191. nr := []rune(strings.ToLower(query))
  192. if opts.Level1 == "" {
  193. opts.Level1 = LEVEL1
  194. }
  195. if opts.Level2 == "" {
  196. opts.Level2 = LEVEL2
  197. }
  198. if opts.Level3 == "" {
  199. opts.Level3 = LEVEL3
  200. }
  201. ropts := resolved_options_type{
  202. level1: []rune(opts.Level1), level2: []rune(opts.Level2), level3: []rune(opts.Level3),
  203. }
  204. ctx.Parallel(0, len(items), func(nums <-chan int) {
  205. w := workspace_type{}
  206. for i := range nums {
  207. results <- score_item(items[i], i, nr, &ropts, &w)
  208. }
  209. })
  210. close(results)
  211. for x := range results {
  212. ans[x.idx] = x
  213. }
  214. return ans
  215. }