snowflake.rb 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # frozen_string_literal: true
  module Mastodon::Snowflake
    # Matches a column default of the form timestamp_id('<prefix>'...),
    # capturing <prefix> as :seq_prefix (used to derive "<prefix>_id_seq").
    DEFAULT_REGEX = /timestamp_id\('(?<seq_prefix>\w+)'/

    # around_create callback that gives back-dated records an ID derived from
    # their created_at instead of the database-generated default.
    class Callbacks
      # Wraps record creation.
      #
      # The database default ID is kept (the block is simply yielded) when
      # created_at is nil, is not strictly before the current UTC time, equals
      # updated_at, or when record.override_timestamps is truthy. Otherwise the
      # ID is pre-assigned with Mastodon::Snowflake.id_at(created_at) and, on a
      # uniqueness collision, nudged upward and the insert retried.
      #
      # @param record the model instance being created
      # @yield performs the actual INSERT
      # @raise [ActiveRecord::RecordNotUnique] re-raised when the ID still
      #   collides after 101 retries
      def self.around_create(record)
        now = Time.now.utc

        if record.created_at.nil? || record.created_at >= now || record.created_at == record.updated_at || record.override_timestamps
          yield
        else
          record.id = Mastodon::Snowflake.id_at(record.created_at)
          tries     = 0

          begin
            yield
          rescue ActiveRecord::RecordNotUnique
            raise if tries > 100

            tries += 1
            # Bump by 1..100 rather than 0..99: a zero bump would re-send the
            # exact ID that just collided and waste one of the attempts.
            # NOTE(review): on PostgreSQL a failed INSERT aborts any enclosing
            # transaction; confirm this INSERT is not running inside one (or is
            # guarded by a savepoint), otherwise the retry cannot succeed.
            record.id += rand(1..100)
            retry
          end
        end
      end
    end

    class << self
      # Our ID will be composed of the following:
      # 6 bytes (48 bits) of millisecond-level timestamp
      # 2 bytes (16 bits) of sequence data
      #
      # The 'sequence data' is intended to be unique within a
      # given millisecond, yet obscure the 'serial number' of
      # this row.
      #
      # To do this, we hash the following data:
      # * Table name (if provided, skipped if not)
      # * Secret salt (should not be guessable)
      # * Timestamp (again, millisecond-level granularity)
      #
      # We then take the first two bytes of that value, and add
      # the lowest two bytes of the table ID sequence number
      # (`table_name`_id_seq). This means that even if we insert
      # two rows at the same millisecond, they will have
      # distinct 'sequence data' portions.
      #
      # If this happens, and an attacker can see both such IDs,
      # they can determine which of the two entries was inserted
      # first, but not the total number of entries in the table
      # (even mod 2**16).
      #
      # The table name is included in the hash to ensure that
      # different tables derive separate sequence bases so rows
      # inserted in the same millisecond in different tables do
      # not reveal the table ID sequence number for one another.
      #
      # The secret salt is included in the hash to ensure that
      # external users cannot derive the sequence base given the
      # timestamp and table name, which would allow them to
      # compute the table ID sequence number.
      #
      # Installs the timestamp_id(table_name) PL/pgSQL function with a freshly
      # generated salt, unless pg_proc already lists a function named
      # timestamp_id (in which case nothing is executed).
      def define_timestamp_id
        return if already_defined?

        connection.execute(sanitized_timestamp_id_sql)
      end

      # For every table whose "id" column default matches DEFAULT_REGEX,
      # creates the "<seq_prefix>_id_seq" sequence, silently skipping any
      # sequence that already exists.
      def ensure_id_sequences_exist
        # Find tables using timestamp IDs.
        connection.tables.each do |table|
          # We're only concerned with "id" columns.
          next unless (id_col = connection.columns(table).find { |col| col.name == 'id' })

          # And only those that are using timestamp_id.
          next unless (data = DEFAULT_REGEX.match(id_col.default_function))

          seq_name = "#{data[:seq_prefix]}_id_seq"

          # If we were on Postgres 9.5+, we could do CREATE SEQUENCE IF
          # NOT EXISTS, but we can't depend on that. Instead, catch the
          # possible exception and ignore it.
          # Note that seq_name isn't a column name, but it's a
          # relation, like a column, and follows the same quoting rules
          # in Postgres.
          connection.execute(<<~SQL)
            DO $$
            BEGIN
            CREATE SEQUENCE #{connection.quote_column_name(seq_name)};
            EXCEPTION WHEN duplicate_table THEN
            -- Do nothing, we have the sequence already.
            END
            $$ LANGUAGE plpgsql;
          SQL
        end
      end

      # Builds a snowflake ID for +timestamp+: its whole seconds converted to
      # milliseconds occupy the upper bits, the low 16 bits hold sequence data.
      #
      # @param timestamp [#to_i] time to encode; sub-second precision is dropped
      # @param with_random [Boolean] when true, adds a random millisecond offset
      #   (0..999) and random 16-bit sequence data; when false, returns the
      #   smallest ID whose timestamp falls in that second
      # @return [Integer] the encoded ID
      def id_at(timestamp, with_random: true)
        id  = timestamp.to_i * 1000
        id += rand(1000) if with_random
        id <<= 16
        id += rand(2**16) if with_random
        id
      end

      # Decodes the timestamp portion of a snowflake ID.
      #
      # @param id [Integer] an ID produced by id_at or timestamp_id
      # @return [Time] UTC time, truncated to whole seconds
      def to_time(id)
        Time.at((id >> 16) / 1000).utc
      end

      private

      # Whether pg_proc lists any function named timestamp_id.
      def already_defined?
        connection.execute(<<~SQL.squish).values.first.first
          SELECT EXISTS(
          SELECT * FROM pg_proc WHERE proname = 'timestamp_id'
          );
        SQL
      end

      # The CREATE FUNCTION statement with :random_string bound to a new salt.
      def sanitized_timestamp_id_sql
        ActiveRecord::Base.sanitize_sql_array(timestamp_id_sql_array)
      end

      # [sql_template, binds] pair; the salt is 16 random bytes, hex-encoded.
      def timestamp_id_sql_array
        [timestamp_id_sql_string, { random_string: SecureRandom.hex(16) }]
      end

      # PL/pgSQL source of timestamp_id(table_name); :random_string is a named
      # bind placeholder filled in by sanitized_timestamp_id_sql.
      def timestamp_id_sql_string
        <<~SQL
          CREATE OR REPLACE FUNCTION timestamp_id(table_name text)
          RETURNS bigint AS
          $$
          DECLARE
          time_part bigint;
          sequence_base bigint;
          tail bigint;
          BEGIN
          time_part := (
          -- Get the time in milliseconds
          ((date_part('epoch', now()) * 1000))::bigint
          -- And shift it over two bytes
          << 16);
          sequence_base := (
          'x' ||
          -- Take the first two bytes (four hex characters)
          substr(
          -- Of the MD5 hash of the data we documented
          md5(table_name || :random_string || time_part::text),
          1, 4
          )
          -- And turn it into a bigint
          )::bit(16)::bigint;
          -- Finally, add our sequence number to our base, and chop
          -- it to the last two bytes
          tail := (
          (sequence_base + nextval(table_name || '_id_seq'))
          & 65535);
          -- Return the time part and the sequence part. OR appears
          -- faster here than addition, but they're equivalent:
          -- time_part has no trailing two bytes, and tail is only
          -- the last two bytes.
          RETURN time_part | tail;
          END
          $$ LANGUAGE plpgsql VOLATILE;
        SQL
      end

      # The shared ActiveRecord connection used for all statements above.
      def connection
        ActiveRecord::Base.connection
      end
    end
  end