index.js 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529
  1. // @ts-check
  2. const fs = require('fs');
  3. const http = require('http');
  4. const url = require('url');
  5. const dotenv = require('dotenv');
  6. const express = require('express');
  7. const Redis = require('ioredis');
  8. const { JSDOM } = require('jsdom');
  9. const log = require('npmlog');
  10. const pg = require('pg');
  11. const dbUrlToConfig = require('pg-connection-string').parse;
  12. const metrics = require('prom-client');
  13. const uuid = require('uuid');
  14. const WebSocket = require('ws');
  // NODE_ENV selects both the dotenv file loaded here and the PostgreSQL
  // presets used by pgConfigFromEnv.
  const environment = process.env.NODE_ENV || 'development';
  const dotenvPath = environment === 'production' ? '.env.production' : '.env';
  dotenv.config({ path: dotenvPath });

  // Read only after dotenv.config() so a LOG_LEVEL from the dotenv file can
  // take effect.
  log.level = process.env.LOG_LEVEL || 'verbose';
  /**
   * Creates an ioredis client and attaches an `error` listener, so connection
   * errors are logged instead of being emitted with no listener.
   *
   * @param {Object.<string, any>} config the result of redisConfigFromEnv();
   *   only `redisUrl` and `redisParams` are read
   * @returns {Promise<Redis>} the new client
   */
  const createRedisClient = async (config) => {
    const client = new Redis(config.redisUrl, config.redisParams);
    const logRedisError = (err) => log.error('Redis Client Error!', err);
    client.on('error', logRedisError);
    return client;
  };
  /**
   * Parses a string as JSON without throwing.
   *
   * Used both for messages arriving from Redis and for messages a client sends
   * over a WebSocket connection. The optional `req` only changes how a parse
   * failure is logged.
   *
   * @param {string} json the raw text to parse
   * @param {any?} req the client request, or null when the text came from Redis
   * @returns {Object.<string, any>|null} the parsed value, or null when parsing
   *   failed (note that the literal text "null" also yields null)
   */
  const parseJSON = (json, req) => {
    let parsed = null;

    try {
      parsed = JSON.parse(json);
    } catch (err) {
      // FIXME: Logging belongs at the call sites rather than in here, but that
      // needs a Result-like return value such as
      // [Error|null, null|Object<string, any>] and error handling in every
      // caller.
      if (!req) {
        log.warn(`Error parsing message from redis: ${err}`);
      } else if (req.accountId) {
        log.warn(req.requestId, `Error parsing message from user ${req.accountId}: ${err}`);
      } else {
        // Requests without an account id are only logged at the silly level.
        log.silly(req.requestId, `Error parsing message from ${req.remoteAddress}: ${err}`);
      }
    }

    return parsed;
  };
  /**
   * Builds the node-postgres pool configuration from environment variables.
   *
   * When DATABASE_URL is set it takes precedence over the individual DB_*
   * variables; only DB_PASS may still supply a password the URL lacks, and
   * DB_SSLMODE is ignored. Otherwise the preset for the current NODE_ENV
   * (`environment`) is used. A NODE_ENV with no preset (e.g. "test") starts
   * from an empty object, leaving the connection details to node-postgres'
   * own defaults.
   *
   * @param {Object.<string, any>} env the `process.env` value to read configuration from
   * @returns {Object.<string, any>} the configuration for the PostgreSQL connection
   */
  const pgConfigFromEnv = (env) => {
    const pgConfigs = {
      development: {
        user: env.DB_USER || pg.defaults.user,
        password: env.DB_PASS || pg.defaults.password,
        database: env.DB_NAME || 'mastodon_development',
        host: env.DB_HOST || pg.defaults.host,
        port: env.DB_PORT || pg.defaults.port,
      },
      production: {
        user: env.DB_USER || 'mastodon',
        password: env.DB_PASS || '',
        database: env.DB_NAME || 'mastodon_production',
        host: env.DB_HOST || 'localhost',
        port: env.DB_PORT || 5432,
      },
    };

    let baseConfig;

    if (env.DATABASE_URL) {
      baseConfig = dbUrlToConfig(env.DATABASE_URL);

      // Support overriding the database password in the connection URL
      if (!baseConfig.password && env.DB_PASS) {
        baseConfig.password = env.DB_PASS;
      }
    } else {
      // Previously an environment without a preset left baseConfig undefined,
      // so setting DB_SSLMODE crashed with a TypeError on `baseConfig.ssl`.
      // The own-property check also keeps names like "constructor" from
      // resolving to Object.prototype members.
      baseConfig = Object.prototype.hasOwnProperty.call(pgConfigs, environment)
        ? pgConfigs[environment]
        : {};

      if (env.DB_SSLMODE) {
        switch (env.DB_SSLMODE) {
          case 'disable':
          case '': // NOTE(review): unreachable; '' fails the enclosing check
            baseConfig.ssl = false;
            break;
          case 'no-verify':
            baseConfig.ssl = { rejectUnauthorized: false };
            break;
          default:
            baseConfig.ssl = {};
            break;
        }
      }
    }

    return {
      ...baseConfig,
      max: env.DB_POOL || 10,
      connectionTimeoutMillis: 15000,
      application_name: '',
    };
  };
  /**
   * Derives the ioredis connection settings and the key/channel prefix from
   * environment variables.
   *
   * @param {Object.<string, any>} env the `process.env` value to read configuration from
   * @returns {Object.<string, any>} `{ redisParams, redisPrefix, redisUrl }`
   *   describing the Redis connection
   */
  const redisConfigFromEnv = (env) => {
    const unixSocketScheme = 'unix://';
    const redisUrl = env.REDIS_URL;

    const redisParams = {
      host: env.REDIS_HOST || '127.0.0.1',
      port: env.REDIS_PORT || 6379,
      db: env.REDIS_DB || 0,
      password: env.REDIS_PASSWORD || undefined,
    };

    // A unix:// URL becomes an explicit socket path, which takes precedence
    // over host and port.
    if (redisUrl && redisUrl.startsWith(unixSocketScheme)) {
      redisParams.path = redisUrl.slice(unixSocketScheme.length);
    }

    // The prefix is applied by hand: ioredis can add prefixes transparently,
    // but it does not do so in every case, so it cannot be relied on here
    // (worth revisiting).
    let redisPrefix = '';
    if (env.REDIS_NAMESPACE) {
      redisPrefix = `${env.REDIS_NAMESPACE}:`;
    }

    return { redisParams, redisPrefix, redisUrl };
  };
  // Channels that checkScopes lets through without any OAuth scope.
  const PUBLIC_CHANNELS = [
    'public', 'public:media',
    'public:local', 'public:local:media',
    'public:remote', 'public:remote:media',
    'hashtag', 'hashtag:local',
  ];

  // Every channel name the server streams; the per-channel counters and gauges
  // are primed with these.
  const CHANNEL_NAMES = ['system', 'user', 'user:notification', 'list', 'direct'].concat(PUBLIC_CHANNELS);
  155. const startServer = async () => {
  const app = express();

  // Express "trust proxy" setting: TRUSTED_PROXY_IP is a comma- and/or
  // whitespace-separated list; without it, loopback and unique-local
  // addresses are trusted.
  app.set('trust proxy', process.env.TRUSTED_PROXY_IP
    ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/)
    : 'loopback,uniquelocal');

  const server = http.createServer(app);
  const pgPool = new pg.Pool(pgConfigFromEnv(process.env));

  /**
   * Listeners keyed by full (prefixed) Redis channel name; maintained by
   * subscribe/unsubscribe.
   * @type {Object.<string, Array.<function(Object<string, any>): void>>}
   */
  const subs = {};

  const redisConfig = redisConfigFromEnv(process.env);
  const { redisPrefix } = redisConfig;
  // Separate connections: one is used for SUBSCRIBE/UNSUBSCRIBE, the other
  // for regular commands such as SET.
  const redisSubscribeClient = await createRedisClient(redisConfig);
  const redisClient = await createRedisClient(redisConfig);
  // Default process/runtime metrics provided by prom-client.
  metrics.collectDefaultMetrics();

  // Connection-pool gauges, refreshed from node-postgres whenever metrics are
  // collected. `collect` must stay a regular method because it uses `this`
  // (the gauge) to set the value.
  [
    ['pg_pool_total_connections', 'The total number of clients existing within the pool', () => pgPool.totalCount],
    ['pg_pool_idle_connections', 'The number of clients which are not checked out but are currently idle in the pool', () => pgPool.idleCount],
    ['pg_pool_waiting_queries', 'The number of queued requests waiting on a client when all clients are checked out', () => pgPool.waitingCount],
  ].forEach(([name, help, readPoolStat]) => {
    new metrics.Gauge({
      name,
      help,
      collect() {
        this.set(readPoolStat());
      },
    });
  });
  const connectedClients = new metrics.Gauge({
    name: 'connected_clients',
    help: 'The number of clients connected to the streaming server',
    labelNames: ['type'],
  });

  const connectedChannels = new metrics.Gauge({
    name: 'connected_channels',
    help: 'The number of channels the streaming server is streaming to',
    labelNames: ['type', 'channel'],
  });

  const redisSubscriptions = new metrics.Gauge({
    name: 'redis_subscriptions',
    help: 'The number of Redis channels the streaming server is subscribed to',
  });

  const redisMessagesReceived = new metrics.Counter({
    name: 'redis_messages_received_total',
    help: 'The total number of messages the streaming server has received from redis subscriptions',
  });

  const messagesSent = new metrics.Counter({
    name: 'messages_sent_total',
    help: 'The total number of messages the streaming server sent to clients per connection type',
    labelNames: ['type'],
  });

  // Prime every series at zero so the metrics are not lost between restarts.
  // Counters have no set(), so inc(0) is used to create them instead.
  redisSubscriptions.set(0);
  redisMessagesReceived.inc(0);
  ['websocket', 'eventsource'].forEach((type) => {
    connectedClients.set({ type }, 0);
    messagesSent.inc({ type }, 0);
  });

  // The set of channels is finite, so every (type, channel) pair is primed.
  CHANNEL_NAMES.forEach((channel) => {
    ['websocket', 'eventsource'].forEach((type) => {
      connectedChannels.set({ type, channel }, 0);
    });
  });
  // Browsers request /favicon.ico when /metrics is opened in a tab. Answering
  // 404 here keeps that request from falling through to the API router, which
  // would error for this path.
  app.get('/favicon.ico', (req, res) => res.status(404).end());

  // Health-check endpoint: always answers 200 with a plain-text "OK".
  app.get('/api/v1/streaming/health', (req, res) => {
    res.writeHead(200, { 'Content-Type': 'text/plain' });
    res.end('OK');
  });

  // Prometheus exposition endpoint. Any failure while rendering is logged and
  // answered with a bare 500.
  app.get('/metrics', async (req, res) => {
    try {
      res.set('Content-Type', metrics.register.contentType);
      const body = await metrics.register.metrics();
      res.end(body);
    } catch (ex) {
      log.error(ex);
      res.status(500).end();
    }
  });
  /**
   * Advertises in Redis that this process is subscribed to `channels`.
   *
   * Each channel's `<redisPrefix>subscribed:<channel>` key is set to '1' with
   * a TTL of three heartbeat intervals, once immediately and then once per
   * interval. If the process stops refreshing them, the keys expire on their
   * own.
   *
   * @param {string[]} channels the channel names to advertise
   * @returns {function(): void} stops the heartbeat; existing keys are left to expire
   */
  const subscriptionHeartbeat = (channels) => {
    // Heartbeat period in seconds (6 minutes); keys live for three periods.
    const interval = 6 * 60;

    const tellSubscribed = () => {
      channels.forEach((channel) => {
        // ioredis commands return promises. Without a handler, a failed SET
        // (e.g. while Redis is unreachable) becomes an unhandled rejection,
        // which terminates Node.js 15+ under the default settings.
        redisClient.set(`${redisPrefix}subscribed:${channel}`, '1', 'EX', interval * 3)
          .catch((err) => log.error(`Error refreshing subscription heartbeat for ${channel}: ${err}`));
      });
    };

    tellSubscribed();

    const heartbeat = setInterval(tellSubscribed, interval * 1000);

    return () => {
      clearInterval(heartbeat);
    };
  };
  /**
   * Dispatches a message received on a Redis channel to every listener
   * registered for that channel in `subs`. Channels without listeners and
   * payloads that are not valid JSON are dropped; the received-messages
   * counter is incremented either way.
   *
   * @param {string} channel the full Redis channel name; callers of
   *   `subscribe` pass names with `redisPrefix` already applied
   * @param {string} message the raw message body, expected to be JSON
   */
  const onRedisMessage = (channel, message) => {
    redisMessagesReceived.inc();

    const callbacks = subs[channel];

    // `channel` already carries redisPrefix. Prefixing it again here used to
    // print the namespace twice.
    log.silly(`New message on channel ${channel}`);

    if (!callbacks) {
      return;
    }

    const json = parseJSON(message, null);

    if (!json) {
      return;
    }

    callbacks.forEach((callback) => callback(json));
  };
  // Every message on any channel the subscriber connection listens to is
  // routed through onRedisMessage.
  redisSubscribeClient.on('message', onRedisMessage);

  /**
   * A listener for messages published on a Redis channel.
   *
   * @callback SubscriptionListener
   * @param {ReturnType<typeof parseJSON>} json the parsed message
   * @returns {void}
   */
  /**
   * Registers `callback` for messages on the Redis `channel`.
   *
   * The Redis SUBSCRIBE is only issued for the first listener of a channel;
   * later listeners share it. On success, the `redis_subscriptions` gauge is
   * set to the subscription count Redis reports.
   *
   * @param {string} channel the full (prefixed) Redis channel name
   * @param {SubscriptionListener} callback
   */
  const subscribe = (channel, callback) => {
    log.silly(`Adding listener for ${channel}`);

    if (!subs[channel]) {
      subs[channel] = [];
    }

    const isFirstListener = subs[channel].length === 0;

    if (isFirstListener) {
      log.verbose(`Subscribe ${channel}`);
      redisSubscribeClient.subscribe(channel, (err, count) => {
        if (err) {
          log.error(`Error subscribing to ${channel}`);
          return;
        }
        redisSubscriptions.set(count);
      });
    }

    subs[channel].push(callback);
  };
  /**
   * Removes `callback` from the listeners of the Redis `channel`.
   *
   * When the last listener is removed, the channel is UNSUBSCRIBEd in Redis
   * and dropped from `subs`, and the `redis_subscriptions` gauge is updated
   * with the count Redis reports. Channels that have no entry are ignored.
   *
   * @param {string} channel the full (prefixed) Redis channel name
   * @param {SubscriptionListener} callback
   */
  const unsubscribe = (channel, callback) => {
    log.silly(`Removing listener for ${channel}`);

    const listeners = subs[channel];

    if (!listeners) {
      return;
    }

    const remaining = listeners.filter((listener) => listener !== callback);
    subs[channel] = remaining;

    if (remaining.length > 0) {
      return;
    }

    log.verbose(`Unsubscribe ${channel}`);
    redisSubscribeClient.unsubscribe(channel, (err, count) => {
      if (err) {
        log.error(`Error unsubscribing to ${channel}`);
      } else {
        redisSubscriptions.set(count);
      }
    });
    delete subs[channel];
  };
  // Values treated as "off" by isTruthy, in addition to every falsy value:
  // common query-string spellings plus the boolean/number forms.
  const FALSE_VALUES = [
    false,
    0,
    '0',
    'f',
    'F',
    'false',
    'FALSE',
    'off',
    'OFF',
  ];

  /**
   * Interprets a loosely typed flag, such as a query-string parameter.
   *
   * @param {any} value
   * @returns {boolean} false for any falsy value or a member of FALSE_VALUES,
   *   true otherwise
   */
  const isTruthy = (value) =>
    // Coerce explicitly: `value && ...` leaked the falsy input itself
    // (undefined, '', null) instead of the documented boolean.
    Boolean(value) && !FALSE_VALUES.includes(value);
  /**
   * Express middleware that adds permissive CORS headers: any origin may call
   * the streaming API with GET/OPTIONS using the Authorization, Accept and
   * Cache-Control request headers.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const allowCrossDomain = (req, res, next) => {
    const corsHeaders = [
      ['Access-Control-Allow-Origin', '*'],
      ['Access-Control-Allow-Headers', 'Authorization, Accept, Cache-Control'],
      ['Access-Control-Allow-Methods', 'GET, OPTIONS'],
    ];
    corsHeaders.forEach(([name, value]) => res.header(name, value));
    next();
  };
  /**
   * Express middleware that tags each request with a random UUID v4, stored as
   * `req.requestId` (used as the log prefix) and echoed in the X-Request-Id
   * response header.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const setRequestId = (req, res, next) => {
    const requestId = uuid.v4();
    req.requestId = requestId;
    res.header('X-Request-Id', requestId);
    next();
  };
  /**
   * Express middleware that copies the peer address of the underlying socket
   * onto `req.remoteAddress`. That value identifies requests without an
   * account in log messages.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const setRemoteAddress = (req, res, next) => {
    // `req.connection` is a deprecated alias (DEP0133) of `req.socket`.
    req.remoteAddress = req.socket.remoteAddress;
    next();
  };
  /**
   * Tells whether the access token attached to `req` (see accountFromToken)
   * grants at least one of `necessaryScopes`.
   *
   * @param {any} req
   * @param {string[]} necessaryScopes
   * @returns {boolean}
   */
  const isInScope = (req, necessaryScopes) => {
    const { scopes } = req;
    return scopes.some((granted) => necessaryScopes.includes(granted));
  };
  /**
   * Looks up a non-revoked OAuth access token. When a match is found, copies
   * the token id, its space-separated scopes (as an array), the account id,
   * the user's chosen languages and the device id onto `req`.
   *
   * @param {string} token the bearer token presented by the client
   * @param {any} req the request object to annotate
   * @returns {Promise.<void>} rejects with the pool/query error, or with a
   *   401 "Invalid access token" error when no row matches
   */
  const accountFromToken = (token, req) => new Promise((resolve, reject) => {
    pgPool.connect((connectError, client, done) => {
      if (connectError) {
        reject(connectError);
        return;
      }

      client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (queryError, result) => {
        // Return the client to the pool before looking at the outcome.
        done();

        if (queryError) {
          reject(queryError);
          return;
        }

        const [row] = result.rows;

        if (!row) {
          const invalidToken = new Error('Invalid access token');
          invalidToken.status = 401;
          reject(invalidToken);
          return;
        }

        req.accessTokenId = row.id;
        req.scopes = row.scopes.split(' ');
        req.accountId = row.account_id;
        req.chosenLanguages = row.chosen_languages;
        req.deviceId = row.device_id;
        resolve();
      });
    });
  });
  /**
   * Authenticates `req` and resolves the account through accountFromToken.
   *
   * The token is taken from the Authorization header ("Bearer <token>") when
   * present. Otherwise it comes from the `access_token` query parameter or,
   * failing that, the Sec-WebSocket-Protocol header.
   *
   * @param {any} req
   * @returns {Promise.<void>} rejects with a 401 "Missing access token" error
   *   when no credential is supplied, otherwise settles like accountFromToken
   */
  const accountFromRequest = async (req) => {
    const { authorization } = req.headers;
    const { query } = url.parse(req.url, true);
    const accessToken = query.access_token || req.headers['sec-websocket-protocol'];

    if (!authorization && !accessToken) {
      const missingToken = new Error('Missing access token');
      missingToken.status = 401;
      throw missingToken;
    }

    const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken;
    return accountFromToken(token, req);
  };
  /**
   * Maps a streaming API path to its internal channel name. For the public
   * timelines, the `only_media` query flag selects the `:media` variant.
   *
   * @param {any} req
   * @returns {string|undefined} undefined for paths that are not streams
   */
  const channelNameFromPath = (req) => {
    const { path, query } = req;
    const onlyMedia = isTruthy(query.only_media);

    // path -> [channel, channel when only_media is set]
    const routes = new Map([
      ['/api/v1/streaming/user', ['user', 'user']],
      ['/api/v1/streaming/user/notification', ['user:notification', 'user:notification']],
      ['/api/v1/streaming/public', ['public', 'public:media']],
      ['/api/v1/streaming/public/local', ['public:local', 'public:local:media']],
      ['/api/v1/streaming/public/remote', ['public:remote', 'public:remote:media']],
      ['/api/v1/streaming/hashtag', ['hashtag', 'hashtag']],
      ['/api/v1/streaming/hashtag/local', ['hashtag:local', 'hashtag:local']],
      ['/api/v1/streaming/direct', ['direct', 'direct']],
      ['/api/v1/streaming/list', ['list', 'list']],
    ]);

    const route = routes.get(path);

    if (!route) {
      return undefined;
    }

    return onlyMedia ? route[1] : route[0];
  };
  /**
   * Verifies that the token on `req` may read `channelName`.
   *
   * Public channels need no scope. `read` grants every stream. Otherwise, the
   * notifications stream needs `read:notifications` and every other stream
   * needs `read:statuses`. Whether the user stream includes notifications is
   * decided separately.
   *
   * @param {any} req
   * @param {string|undefined} channelName
   * @returns {Promise.<void>} rejects with a 401 error when no accepted scope
   *   is held
   */
  const checkScopes = async (req, channelName) => {
    log.silly(req.requestId, `Checking OAuth scopes for ${channelName}`);

    if (PUBLIC_CHANNELS.includes(channelName)) {
      return;
    }

    const acceptedScopes = [
      'read',
      channelName === 'user:notification' ? 'read:notifications' : 'read:statuses',
    ];

    const granted = req.scopes;

    if (granted && acceptedScopes.some((scope) => granted.includes(scope))) {
      return;
    }

    const err = new Error('Access token does not cover required scopes');
    err.status = 401;
    throw err;
  };
  /**
   * `verifyClient` hook for WebSocket upgrades. The connection is accepted when
   * the request authenticates through accountFromRequest. Otherwise the
   * failure is logged and the upgrade is refused with 401 Unauthorized.
   *
   * OAuth scopes are deliberately not checked here; they are checked when the
   * client subscribes to a specific stream.
   *
   * @param {any} info
   * @param {function(boolean, number, string): void} callback
   */
  const wsVerifyClient = (info, callback) => {
    const accept = () => callback(true, undefined, undefined);
    const refuse = (err) => {
      log.error(info.req.requestId, err.toString());
      callback(false, 401, 'Unauthorized');
    };

    accountFromRequest(info.req).then(accept).catch(refuse);
  };
  /**
   * @typedef SystemMessageHandlers
   * @property {function(): void} onKill invoked for a `kill` system event
   */

  /**
   * Builds the listener for a connection's system channels.
   *
   * A `kill` event logs and calls `eventHandlers.onKill()`. A
   * `filters_changed` event clears `req.cachedFilters` so the filters are
   * rebuilt later. Any other event is only logged.
   *
   * @param {any} req
   * @param {SystemMessageHandlers} eventHandlers
   * @returns {function(object): void}
   */
  const createSystemMessageListener = (req, eventHandlers) => (message) => {
    const { event } = message;

    log.silly(req.requestId, `System message for ${req.accountId}: ${event}`);

    switch (event) {
      case 'kill':
        log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`);
        eventHandlers.onKill();
        break;
      case 'filters_changed':
        log.verbose(req.requestId, `Invalidating filters cache for ${req.accountId}`);
        req.cachedFilters = null;
        break;
      default:
        break;
    }
  };
  /**
   * Subscribes an EventSource (HTTP) response to its two system channels: the
   * access-token channel and the account's system channel. A `kill` event ends
   * the response. When the response closes, both subscriptions are removed.
   * The `connected_channels{type="eventsource",channel="system"}` gauge is
   * kept in step with both operations.
   *
   * @param {any} req
   * @param {any} res
   */
  const subscribeHttpToSystemChannel = (req, res) => {
    const channelIds = [
      `${redisPrefix}timeline:access_token:${req.accessTokenId}`,
      `${redisPrefix}timeline:system:${req.accountId}`,
    ];

    const listener = createSystemMessageListener(req, {
      onKill() {
        res.end();
      },
    });

    const systemGauge = () => connectedChannels.labels({ type: 'eventsource', channel: 'system' });

    res.on('close', () => {
      channelIds.forEach((id) => unsubscribe(id, listener));
      systemGauge().dec(channelIds.length);
    });

    channelIds.forEach((id) => subscribe(id, listener));
    systemGauge().inc(channelIds.length);
  };
  /**
   * Express middleware guarding the HTTP (EventSource) streaming endpoints.
   *
   * OPTIONS requests pass straight through. Paths that do not map to a
   * channel fail with 400. Otherwise, the request is authenticated, its scopes
   * are checked for the channel, and it is subscribed to its system channels
   * before the chain continues. Any failure is forwarded to `next`.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const authenticationMiddleware = (req, res, next) => {
    if (req.method === 'OPTIONS') {
      next();
      return;
    }

    const channelName = channelNameFromPath(req);

    // Nothing can be streamed back for an unknown channel.
    if (!channelName) {
      const err = new Error('Unknown channel requested');
      err.status = 400;
      next(err);
      return;
    }

    const authorize = async () => {
      await accountFromRequest(req);
      await checkScopes(req, channelName);
      subscribeHttpToSystemChannel(req, res);
    };

    authorize()
      .then(() => next())
      .catch((err) => next(err));
  };
  /**
   * Express error handler: logs the error and answers with a JSON body. Errors
   * that carry a `status` expose their message; anything else becomes a
   * generic 500.
   *
   * @param {Error} err
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const errorMiddleware = (err, req, res, next) => {
    log.error(req.requestId, err.toString());

    // Once headers are out a new status cannot be sent, so hand the error back
    // to Express.
    if (res.headersSent) {
      next(err);
      return;
    }

    const { status } = err;
    const error = status ? err.toString() : 'An unexpected error occurred';

    res.writeHead(status || 500, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify({ error }));
  };
  /**
   * Builds a comma-separated list of PostgreSQL positional parameters, one per
   * element of `arr`, numbered from `shift + 1`.
   *
   * @example placeholders(['a', 'b'], 2) // => '$3, $4'
   * @param {array} arr
   * @param {number=} shift number of parameters already used before this list
   * @returns {string}
   */
  const placeholders = (arr, shift = 0) => arr
    .map((_, index) => `$${index + 1 + shift}`)
    .join(', ');
  /**
   * Checks that list `listId` exists and belongs to the account on `req`.
   *
   * @param {string} listId
   * @param {any} req request already annotated by accountFromToken
   * @returns {Promise.<void>} resolves when the list is owned by
   *   `req.accountId`. Otherwise it rejects with the pool/query error, or with
   *   a 401 "List not found" error when the list is missing or owned by
   *   another account.
   */
  const authorizeListAccess = (listId, req) => new Promise((resolve, reject) => {
    const { accountId } = req;

    pgPool.connect((connectError, client, done) => {
      if (connectError) {
        // Previously rejected with no reason, dropping the pool error.
        reject(connectError);
        return;
      }

      client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [listId], (queryError, result) => {
        done();

        if (queryError) {
          reject(queryError);
          return;
        }

        if (result.rows.length === 0 || result.rows[0].account_id !== accountId) {
          // Missing lists and lists owned by someone else get the same Error
          // (instead of an undefined rejection reason, which broke any caller
          // calling err.toString()).
          const err = new Error('List not found');
          err.status = 401;
          reject(err);
          return;
        }

        resolve();
      });
    });
  });
  622. /**
  623. * @param {string[]} ids
  624. * @param {any} req
  625. * @param {function(string, string): void} output
  626. * @param {undefined | function(string[], SubscriptionListener): void} attachCloseHandler
  627. * @param {'websocket' | 'eventsource'} destinationType
  628. * @param {boolean=} needsFiltering
  629. * @returns {SubscriptionListener}
  630. */
  631. const streamFrom = (ids, req, output, attachCloseHandler, destinationType, needsFiltering = false) => {
  632. const accountId = req.accountId || req.remoteAddress;
  633. log.verbose(req.requestId, `Starting stream from ${ids.join(', ')} for ${accountId}`);
  634. const transmit = (event, payload) => {
  635. // TODO: Replace "string"-based delete payloads with object payloads:
  636. const encodedPayload = typeof payload === 'object' ? JSON.stringify(payload) : payload;
  637. messagesSent.labels({ type: destinationType }).inc(1);
  638. log.silly(req.requestId, `Transmitting for ${accountId}: ${event} ${encodedPayload}`);
  639. output(event, encodedPayload);
  640. };
  641. // The listener used to process each message off the redis subscription,
  642. // message here is an object with an `event` and `payload` property. Some
  643. // events also include a queued_at value, but this is being removed shortly.
  644. /** @type {SubscriptionListener} */
  645. const listener = message => {
  646. const { event, payload } = message;
  647. // Only send local-only statuses to logged-in users
  648. if (payload.local_only && !req.accountId) {
  649. log.silly(req.requestId, `Message ${payload.id} filtered because it was local-only`);
  650. return;
  651. }
  652. // Streaming only needs to apply filtering to some channels and only to
  653. // some events. This is because majority of the filtering happens on the
  654. // Ruby on Rails side when producing the event for streaming.
  655. //
  656. // The only events that require filtering from the streaming server are
  657. // `update` and `status.update`, all other events are transmitted to the
  658. // client as soon as they're received (pass-through).
  659. //
  660. // The channels that need filtering are determined in the function
  661. // `channelNameToIds` defined below:
  662. if (!needsFiltering || (event !== 'update' && event !== 'status.update')) {
  663. transmit(event, payload);
  664. return;
  665. }
  666. // The rest of the logic from here on in this function is to handle
  667. // filtering of statuses:
  668. // Filter based on language:
  669. if (Array.isArray(req.chosenLanguages) && payload.language !== null && req.chosenLanguages.indexOf(payload.language) === -1) {
  670. log.silly(req.requestId, `Message ${payload.id} filtered by language (${payload.language})`);
  671. return;
  672. }
  673. // When the account is not logged in, it is not necessary to confirm the block or mute
  674. if (!req.accountId) {
  675. transmit(event, payload);
  676. return;
  677. }
  678. // Filter based on domain blocks, blocks, mutes, or custom filters:
  679. const targetAccountIds = [payload.account.id].concat(payload.mentions.map(item => item.id));
  680. const accountDomain = payload.account.acct.split('@')[1];
  681. // TODO: Move this logic out of the message handling loop
  682. pgPool.connect((err, client, releasePgConnection) => {
  683. if (err) {
  684. log.error(err);
  685. return;
  686. }
  687. const queries = [
  688. client.query(`SELECT 1
  689. FROM blocks
  690. WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
  691. OR (account_id = $2 AND target_account_id = $1)
  692. UNION
  693. SELECT 1
  694. FROM mutes
  695. WHERE account_id = $1
  696. AND target_account_id IN (${placeholders(targetAccountIds, 2)})`, [req.accountId, payload.account.id].concat(targetAccountIds)),
  697. ];
  698. if (accountDomain) {
  699. queries.push(client.query('SELECT 1 FROM account_domain_blocks WHERE account_id = $1 AND domain = $2', [req.accountId, accountDomain]));
  700. }
  701. if (!payload.filtered && !req.cachedFilters) {
  702. queries.push(client.query('SELECT filter.id AS id, filter.phrase AS title, filter.context AS context, filter.expires_at AS expires_at, filter.action AS filter_action, keyword.keyword AS keyword, keyword.whole_word AS whole_word FROM custom_filter_keywords keyword JOIN custom_filters filter ON keyword.custom_filter_id = filter.id WHERE filter.account_id = $1 AND (filter.expires_at IS NULL OR filter.expires_at > NOW())', [req.accountId]));
  703. }
  704. Promise.all(queries).then(values => {
  705. releasePgConnection();
  706. // Handling blocks & mutes and domain blocks: If one of those applies,
  707. // then we don't transmit the payload of the event to the client
  708. if (values[0].rows.length > 0 || (accountDomain && values[1].rows.length > 0)) {
  709. return;
  710. }
  711. // If the payload already contains the `filtered` property, it means
  712. // that filtering has been applied on the ruby on rails side, as
  713. // such, we don't need to construct or apply the filters in streaming:
  714. if (Object.prototype.hasOwnProperty.call(payload, "filtered")) {
  715. transmit(event, payload);
  716. return;
  717. }
  718. // Handling for constructing the custom filters and caching them on the request
  719. // TODO: Move this logic out of the message handling lifecycle
  720. if (!req.cachedFilters) {
  721. const filterRows = values[accountDomain ? 2 : 1].rows;
  722. req.cachedFilters = filterRows.reduce((cache, filter) => {
  723. if (cache[filter.id]) {
  724. cache[filter.id].keywords.push([filter.keyword, filter.whole_word]);
  725. } else {
  726. cache[filter.id] = {
  727. keywords: [[filter.keyword, filter.whole_word]],
  728. expires_at: filter.expires_at,
  729. filter: {
  730. id: filter.id,
  731. title: filter.title,
  732. context: filter.context,
  733. expires_at: filter.expires_at,
  734. // filter.filter_action is the value from the
  735. // custom_filters.action database column, it is an integer
  736. // representing a value in an enum defined by Ruby on Rails:
  737. //
  738. // enum { warn: 0, hide: 1 }
  739. filter_action: ['warn', 'hide'][filter.filter_action],
  740. },
  741. };
  742. }
  743. return cache;
  744. }, {});
  745. // Construct the regular expressions for the custom filters: This
  746. // needs to be done in a separate loop as the database returns one
  747. // filterRow per keyword, so we need all the keywords before
  748. // constructing the regular expression
  749. Object.keys(req.cachedFilters).forEach((key) => {
  750. req.cachedFilters[key].regexp = new RegExp(req.cachedFilters[key].keywords.map(([keyword, whole_word]) => {
  751. let expr = keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  752. if (whole_word) {
  753. if (/^[\w]/.test(expr)) {
  754. expr = `\\b${expr}`;
  755. }
  756. if (/[\w]$/.test(expr)) {
  757. expr = `${expr}\\b`;
  758. }
  759. }
  760. return expr;
  761. }).join('|'), 'i');
  762. });
  763. }
  764. // Apply cachedFilters against the payload, constructing a
  765. // `filter_results` array of FilterResult entities
  766. if (req.cachedFilters) {
  767. const status = payload;
  768. // TODO: Calculate searchableContent in Ruby on Rails:
  769. const searchableContent = ([status.spoiler_text || '', status.content].concat((status.poll && status.poll.options) ? status.poll.options.map(option => option.title) : [])).concat(status.media_attachments.map(att => att.description)).join('\n\n').replace(/<br\s*\/?>/g, '\n').replace(/<\/p><p>/g, '\n\n');
  770. const searchableTextContent = JSDOM.fragment(searchableContent).textContent;
  771. const now = new Date();
  772. const filter_results = Object.values(req.cachedFilters).reduce((results, cachedFilter) => {
  773. // Check the filter hasn't expired before applying:
  774. if (cachedFilter.expires_at !== null && cachedFilter.expires_at < now) {
  775. return results;
  776. }
  777. // Just in-case JSDOM fails to find textContent in searchableContent
  778. if (!searchableTextContent) {
  779. return results;
  780. }
  781. const keyword_matches = searchableTextContent.match(cachedFilter.regexp);
  782. if (keyword_matches) {
  783. // results is an Array of FilterResult; status_matches is always
  784. // null as we only are only applying the keyword-based custom
  785. // filters, not the status-based custom filters.
  786. // https://docs.joinmastodon.org/entities/FilterResult/
  787. results.push({
  788. filter: cachedFilter.filter,
  789. keyword_matches,
  790. status_matches: null
  791. });
  792. }
  793. return results;
  794. }, []);
  795. // Send the payload + the FilterResults as the `filtered` property
  796. // to the streaming connection. To reach this code, the `event` must
  797. // have been either `update` or `status.update`, meaning the
  798. // `payload` is a Status entity, which has a `filtered` property:
  799. //
  800. // filtered: https://docs.joinmastodon.org/entities/Status/#filtered
  801. transmit(event, {
  802. ...payload,
  803. filtered: filter_results
  804. });
  805. } else {
  806. transmit(event, payload);
  807. }
  808. }).catch(err => {
  809. log.error(err);
  810. releasePgConnection();
  811. });
  812. });
  813. };
  814. ids.forEach(id => {
  815. subscribe(`${redisPrefix}${id}`, listener);
  816. });
  817. if (typeof attachCloseHandler === 'function') {
  818. attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener);
  819. }
  820. return listener;
  821. };
  /**
   * Prepare an HTTP response to be used as a Server-Sent Events (EventSource)
   * stream and return the function that writes events to it.
   *
   * Side effects: increments the `eventsource` connected-clients gauge (and the
   * per-channel gauge when a channel name can be derived from the path), sends
   * the SSE headers plus an initial comment line, and writes a `:thump`
   * comment every 15 seconds. The gauges are decremented and the timer is
   * cleared when the request emits `close`.
   *
   * @param {any} req
   * @param {any} res
   * @returns {function(string, string): void} writer taking an event name and
   *   its already-encoded payload
   */
  const streamToHttp = (req, res) => {
    const accountId = req.accountId || req.remoteAddress;
    const channelName = channelNameFromPath(req);

    connectedClients.labels({ type: 'eventsource' }).inc();

    // In theory we'll always have a channel name, but channelNameFromPath can return undefined:
    if (typeof channelName === 'string') {
      connectedChannels.labels({ type: 'eventsource', channel: channelName }).inc();
    }

    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-store');
    res.setHeader('Transfer-Encoding', 'chunked');

    // Lines starting with ':' are SSE comments, which clients ignore. The
    // periodic `:thump` presumably keeps otherwise idle connections from being
    // timed out by intermediaries.
    res.write(':)\n');
    const heartbeat = setInterval(() => res.write(':thump\n'), 15000);

    req.on('close', () => {
      log.verbose(req.requestId, `Ending stream for ${accountId}`);
      // We decrement these counters here instead of in streamHttpEnd as in that
      // method we don't have knowledge of the channel names
      connectedClients.labels({ type: 'eventsource' }).dec();
      // In theory we'll always have a channel name, but channelNameFromPath can return undefined:
      if (typeof channelName === 'string') {
        connectedChannels.labels({ type: 'eventsource', channel: channelName }).dec();
      }
      clearInterval(heartbeat);
    });

    return (event, payload) => {
      // An SSE field ends at the first line break, so a payload containing
      // CR/LF must be sent as one `data:` line per payload line (the client
      // joins them back with '\n'). A single-line payload still produces
      // exactly `data: <payload>\n\n`.
      const dataLines = `${payload}`.split(/\r\n|\r|\n/).map((line) => `data: ${line}`);
      res.write(`event: ${event}\n`);
      res.write(`${dataLines.join('\n')}\n\n`);
    };
  };
  /**
   * Build the "end of stream" hook for an HTTP (EventSource) subscription.
   *
   * The returned function registers a one-shot cleanup on the request's
   * `close` event: it unsubscribes `listener` from every id in `ids`, then
   * calls `closeHandler` if one was given.
   *
   * @param {any} req
   * @param {function(): void} [closeHandler]
   * @returns {function(string[], SubscriptionListener): void}
   */
  const streamHttpEnd = (req, closeHandler = undefined) => {
    return (ids, listener) => {
      const teardown = () => {
        for (const id of ids) {
          unsubscribe(id, listener);
        }
        if (closeHandler) {
          closeHandler();
        }
      };
      req.on('close', teardown);
    };
  };
  /**
   * Create a writer that sends events over a WebSocket.
   *
   * Each event is sent as JSON `{ stream, event, payload }`. Writes attempted
   * while the socket is not OPEN are dropped and logged. Asynchronous send
   * failures reported by the socket are logged as well.
   *
   * @param {any} req
   * @param {any} ws
   * @param {string[]} streamName
   * @returns {function(string, string): void}
   */
  const streamToWs = (req, ws, streamName) => {
    const reportSendFailure = (err) => {
      if (err) {
        log.error(req.requestId, `Failed to send to websocket: ${err}`);
      }
    };

    return (event, payload) => {
      if (ws.readyState === ws.OPEN) {
        const frame = JSON.stringify({ stream: streamName, event, payload });
        ws.send(frame, reportSendFailure);
        return;
      }
      log.error(req.requestId, 'Tried writing to closed socket');
    };
  };
  /**
   * Answer the request with a JSON 404 error body.
   *
   * @param {any} res
   */
  const httpNotFound = (res) => {
    const status = 404;
    const body = JSON.stringify({ error: 'Not found' });
    res.writeHead(status, { 'Content-Type': 'application/json' });
    res.end(body);
  };
  const api = express.Router();
  app.use(api);

  // Router-level middleware, registered in exactly this order (Express runs
  // middleware in registration order, so the order is significant).
  const apiMiddleware = [
    setRequestId,
    setRemoteAddress,
    allowCrossDomain,
    authenticationMiddleware,
    errorMiddleware,
  ];
  for (const middleware of apiMiddleware) {
    api.use(middleware);
  }

  // EventSource (HTTP) streaming endpoint: resolve the stream named by the
  // path to Redis channel ids and pipe them into the response. Any failure
  // while resolving or setting up the stream is answered with a 404.
  api.get('/api/v1/streaming/*', (req, res) => {
    const channelName = channelNameFromPath(req);

    channelNameToIds(req, channelName, req.query)
      .then(({ channelIds, options }) => {
        const onSend = streamToHttp(req, res);
        const onEnd = streamHttpEnd(req, subscriptionHeartbeat(channelIds));
        streamFrom(channelIds, req, onSend, onEnd, 'eventsource', options.needsFiltering);
      })
      .catch((err) => {
        log.verbose(req.requestId, 'Subscription error:', err.toString());
        httpNotFound(res);
      });
  });

  const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
  /**
   * @typedef StreamParams
   * @property {string} [tag]
   * @property {string} [list]
   * @property {string} [only_media]
   */

  /**
   * List the Redis channels backing the `user` stream of the authenticated
   * account. These are:
   * - always `timeline:<accountId>`;
   * - `timeline:<accountId>:<deviceId>` when the request is in the `crypto`
   *   scope and carries a device id;
   * - `timeline:<accountId>:notifications` when the request is in the `read`
   *   or `read:notifications` scope.
   *
   * @param {any} req
   * @returns {string[]}
   */
  const channelsForUserStream = (req) => {
    const base = `timeline:${req.accountId}`;
    const wantsDeviceChannel = isInScope(req, ['crypto']) && req.deviceId;
    const wantsNotifications = isInScope(req, ['read', 'read:notifications']);

    return [
      base,
      ...(wantsDeviceChannel ? [`${base}:${req.deviceId}`] : []),
      ...(wantsNotifications ? [`${base}:notifications`] : []),
    ];
  };
  /**
   * See app/lib/ascii_folder.rb for the canon definitions
   * of these constants.
   *
   * The two strings are parallel: the character at index i of
   * NON_ASCII_CHARS folds to the character at index i of
   * EQUIVALENT_ASCII_CHARS, so both must contain exactly the same number of
   * characters.
   *
   * U+0102..U+0105 and U+0149 are written as escapes because they had been
   * mangled in this file. U+0149 NFKC-normalises to two code points, and
   * three of the others had collapsed into one. That shifted every later
   * mapping out of alignment.
   */
  const NON_ASCII_CHARS = 'ÀÁÂÃÄÅàáâãäåĀā\u0102\u0103\u0104\u0105ÇçĆćĈĉĊċČčÐðĎďĐđÈÉÊËèéêëĒēĔĕĖėĘęĚěĜĝĞğĠġĢģĤĥĦħÌÍÎÏìíîïĨĩĪīĬĭĮįİıĴĵĶķĸĹĺĻļĽľĿŀŁłÑñŃńŅņŇň\u0149ŊŋÒÓÔÕÖØòóôõöøŌōŎŏŐőŔŕŖŗŘřŚśŜŝŞşŠšſŢţŤťŦŧÙÚÛÜùúûüŨũŪūŬŭŮůŰűŲųŴŵÝýÿŶŷŸŹźŻżŽž';
  const EQUIVALENT_ASCII_CHARS = 'AAAAAAaaaaaaAaAaAaCcCcCcCcCcDdDdDdEEEEeeeeEeEeEeEeEeGgGgGgGgHhHhIIIIiiiiIiIiIiIiIiJjKkkLlLlLlLlLlNnNnNnNnnNnOOOOOOooooooOoOoOoRrRrRrSsSsSsSssTtTtTtUUUUuuuuUuUuUuUuUuUuWwYyyYyYZzZzZz';

  /**
   * Replace every character listed in NON_ASCII_CHARS with its ASCII
   * equivalent from EQUIVALENT_ASCII_CHARS. All other characters are
   * returned unchanged.
   *
   * @param {string} str
   * @returns {string}
   */
  const foldToASCII = (() => {
    const sourceChars = [...NON_ASCII_CHARS];
    // Fail loudly at startup rather than silently folding characters to the
    // wrong letter (or to an empty string) if the tables drift apart again.
    if (sourceChars.length !== EQUIVALENT_ASCII_CHARS.length) {
      throw new Error('NON_ASCII_CHARS and EQUIVALENT_ASCII_CHARS must have the same length');
    }
    // Built once rather than on every call. String.prototype.replace resets
    // lastIndex for global regexes, so sharing this /g regex is safe.
    const foldingTable = new Map(sourceChars.map((char, index) => [char, EQUIVALENT_ASCII_CHARS[index]]));
    const foldableChar = new RegExp(`[${NON_ASCII_CHARS}]`, 'g');
    return (str) => str.replace(foldableChar, (char) => foldingTable.get(char));
  })();
  /**
   * Canonicalise a hashtag name for use in its timeline channel id. The steps
   * are:
   * 1. NFKC-normalise and lower-case the name.
   * 2. Fold listed accented Latin letters to ASCII.
   * 3. Remove every character that is not a letter, digit, underscore,
   *    middle dot (U+00B7) or zero-width non-joiner (U+200C).
   *
   * @param {string} str
   * @returns {string}
   */
  const normalizeHashtag = (str) => {
    const caseFolded = str.normalize('NFKC').toLowerCase();
    const asciiFolded = foldToASCII(caseFolded);
    return asciiFolded.replace(/[^\p{L}\p{N}_\u00b7\u200c]/gu, '');
  };
  /**
   * Resolve a stream name (e.g. `user`, `public:local`, `hashtag`) to the
   * Redis channel ids to subscribe to. Also returns the `needsFiltering` flag
   * that callers pass on to streamFrom.
   *
   * Rejects with a plain string message when:
   * - a hashtag stream has no tag;
   * - the list is not accessible (authorizeListAccess rejected);
   * - the stream name is unknown.
   *
   * @param {any} req
   * @param {string} name
   * @param {StreamParams} params
   * @returns {Promise.<{ channelIds: string[], options: { needsFiltering: boolean } }>}
   */
  const channelNameToIds = (req, name, params) => new Promise((resolve, reject) => {
    const accept = (channelIds, needsFiltering) => {
      resolve({ channelIds, options: { needsFiltering } });
    };

    // Public timelines map directly onto `timeline:<name>`.
    const filteredPublicStreams = [
      'public',
      'public:local',
      'public:remote',
      'public:media',
      'public:local:media',
      'public:remote:media',
    ];

    if (name === 'user') {
      accept(channelsForUserStream(req), false);
    } else if (name === 'user:notification') {
      accept([`timeline:${req.accountId}:notifications`], false);
    } else if (filteredPublicStreams.includes(name)) {
      accept([`timeline:${name}`], true);
    } else if (name === 'direct') {
      accept([`timeline:direct:${req.accountId}`], false);
    } else if (name === 'hashtag' || name === 'hashtag:local') {
      if (!params.tag || params.tag.length === 0) {
        reject('No tag for stream provided');
        return;
      }
      const suffix = name === 'hashtag:local' ? ':local' : '';
      accept([`timeline:hashtag:${normalizeHashtag(params.tag)}${suffix}`], true);
    } else if (name === 'list') {
      authorizeListAccess(params.list, req)
        .then(() => {
          accept([`timeline:list:${params.list}`], false);
        })
        .catch(() => {
          reject('Not authorized to stream this list');
        });
    } else {
      reject('Unknown stream type');
    }
  });
  /**
   * Build the `stream` value echoed back to WebSocket clients. `list` streams
   * carry the list id, hashtag streams carry the (un-normalised) tag, and all
   * other streams are identified by their name alone.
   *
   * @param {string} channelName
   * @param {StreamParams} params
   * @returns {string[]}
   */
  const streamNameFromChannelName = (channelName, params) => {
    switch (channelName) {
      case 'list':
        return [channelName, params.list];
      case 'hashtag':
      case 'hashtag:local':
        return [channelName, params.tag];
      default:
        return [channelName];
    }
  };
  /**
   * @typedef WebSocketSession
   * @property {any} socket
   * @property {any} request
   * @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
   */

  /**
   * Subscribe a WebSocket session to a stream, after checking that the
   * request's scopes allow it.
   *
   * Subscriptions are keyed by their channel ids joined with ';', so
   * subscribing twice to the same stream is a no-op. Errors (scope check,
   * unknown stream, ...) are logged and, while the socket is still open,
   * reported to the client as `{ error }` JSON.
   *
   * @param {WebSocketSession} session
   * @param {string} channelName
   * @param {StreamParams} params
   * @returns {void}
   */
  const subscribeWebsocketToChannel = ({ socket, request, subscriptions }, channelName, params) => {
    checkScopes(request, channelName)
      .then(() => channelNameToIds(request, channelName, params))
      .then(({ channelIds, options }) => {
        // The scope check and channel lookup are asynchronous, so the socket
        // may have closed (or be closing) in the meantime. The connection's
        // close handler tears down and replaces session.subscriptions.
        // Anything registered after that would never be removed, leaking the
        // Redis listener, the heartbeat and the channel gauge.
        if (socket.readyState !== socket.OPEN) {
          return;
        }

        const subscriptionKey = channelIds.join(';');
        if (subscriptions[subscriptionKey]) {
          return;
        }

        const onSend = streamToWs(request, socket, streamNameFromChannelName(channelName, params));
        const stopHeartbeat = subscriptionHeartbeat(channelIds);
        const listener = streamFrom(channelIds, request, onSend, undefined, 'websocket', options.needsFiltering);

        connectedChannels.labels({ type: 'websocket', channel: channelName }).inc();

        subscriptions[subscriptionKey] = {
          channelName,
          listener,
          stopHeartbeat,
        };
      })
      .catch((err) => {
        // String() instead of err.toString(): a null/undefined rejection value
        // must not make this handler throw (an unhandled rejection).
        const message = String(err);
        log.verbose(request.requestId, 'Subscription error:', message);
        if (socket.readyState === socket.OPEN) {
          socket.send(JSON.stringify({ error: message }));
        }
      });
  };
  /**
   * Tear down one WebSocket subscription, identified by its channel ids.
   *
   * For a known subscription this:
   * - unsubscribes its listener from every (prefixed) Redis channel;
   * - decrements the websocket channel gauge;
   * - stops its heartbeat;
   * - removes it from `subscriptions`.
   *
   * Unknown subscriptions are ignored, but the "Ending stream" line is logged
   * either way.
   *
   * @param {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
   * @param {string[]} channelIds
   * @param {any} request
   * @returns {void}
   */
  const removeSubscription = (subscriptions, channelIds, request) => {
    const subscriptionKey = channelIds.join(';');
    log.verbose(request.requestId, `Ending stream from ${channelIds.join(', ')} for ${request.accountId}`);

    const subscription = subscriptions[subscriptionKey];
    if (subscription) {
      for (const channelId of channelIds) {
        unsubscribe(`${redisPrefix}${channelId}`, subscription.listener);
      }
      connectedChannels.labels({ type: 'websocket', channel: subscription.channelName }).dec();
      subscription.stopHeartbeat();
      delete subscriptions[subscriptionKey];
    }
  };
  /**
   * Remove a session's subscription to a stream.
   *
   * The stream name is resolved to channel ids the same way as when
   * subscribing. If that fails, the error is logged and, while the socket is
   * alive and open, a generic error message is sent back to the client.
   *
   * @param {WebSocketSession} session
   * @param {string} channelName
   * @param {StreamParams} params
   * @returns {void}
   */
  const unsubscribeWebsocketFromChannel = ({ socket, request, subscriptions }, channelName, params) => {
    const reportFailure = (err) => {
      log.verbose(request.requestId, 'Unsubscribe error:', err);
      // FIXME: In other parts of the code ws === socket
      const canReply = socket.isAlive && socket.readyState === socket.OPEN;
      if (canReply) {
        socket.send(JSON.stringify({ error: "Error unsubscribing from channel" }));
      }
    };

    channelNameToIds(request, channelName, params)
      .then(({ channelIds }) => {
        removeSubscription(subscriptions, channelIds, request);
      })
      .catch(reportFailure);
  };
  /**
   * Subscribe a session to its two system channels:
   * `timeline:access_token:<accessTokenId>` and
   * `timeline:system:<accountId>`.
   *
   * Both share one listener from createSystemMessageListener, whose `onKill`
   * callback closes the socket. Both are recorded in the session's
   * subscriptions under the `system` channel name with no-op heartbeat
   * stoppers, and the websocket `system` channel gauge is incremented by 2.
   *
   * @param {WebSocketSession} session
   */
  const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => {
    const listener = createSystemMessageListener(request, {
      onKill() {
        socket.close();
      },
    });

    const systemChannelIds = [
      `timeline:access_token:${request.accessTokenId}`,
      `timeline:system:${request.accountId}`,
    ];

    systemChannelIds.forEach((channelId) => {
      subscribe(`${redisPrefix}${channelId}`, listener);
    });

    systemChannelIds.forEach((channelId) => {
      subscriptions[channelId] = {
        channelName: 'system',
        listener,
        stopHeartbeat: () => {},
      };
    });

    connectedChannels.labels({ type: 'websocket', channel: 'system' }).inc(2);
  };
  /**
   * Normalise a query parameter that may have been supplied several times:
   * arrays yield their first element, anything else is returned as-is.
   *
   * @param {string|string[]} arrayOrString
   * @returns {string}
   */
  const firstParam = (arrayOrString) => (Array.isArray(arrayOrString) ? arrayOrString[0] : arrayOrString);
  // Per-connection setup for WebSocket clients: request metadata, keep-alive
  // state, a session tracking the connection's subscriptions, and handlers
  // for close/error/message. Finally it subscribes to the system channels
  // and to the stream named in the connection URL, if any.
  wss.on('connection', (ws, req) => {
    // Count the client as soon as the connection is established, before any
    // of the setup below runs:
    connectedClients.labels({ type: 'websocket' }).inc();

    // Setup request properties:
    req.requestId = uuid.v4();
    req.remoteAddress = ws._socket.remoteAddress;

    // Setup connection keep-alive state: the periodic ping sweep clears
    // `isAlive` and a pong from the client sets it again.
    ws.isAlive = true;
    ws.on('pong', () => {
      ws.isAlive = true;
    });

    /**
     * @type {WebSocketSession}
     */
    const session = {
      socket: ws,
      request: req,
      subscriptions: {},
    };

    ws.on('close', function onWebsocketClose() {
      const subscriptions = Object.keys(session.subscriptions);
      subscriptions.forEach(channelIds => {
        removeSubscription(session.subscriptions, channelIds.split(';'), req);
      });
      // Decrement the metrics for connected clients:
      connectedClients.labels({ type: 'websocket' }).dec();
      // ensure garbage collection:
      session.socket = null;
      session.request = null;
      session.subscriptions = {};
    });

    // Note: immediately after the `error` event is emitted, the `close` event
    // is emitted. As such, all we need to do is log the error here.
    ws.on('error', (err) => {
      log.error('websocket', err.toString());
    });

    ws.on('message', (data, isBinary) => {
      if (isBinary) {
        log.warn('websocket', 'Received binary data, closing connection');
        ws.close(1003, 'The mastodon streaming server does not support binary messages');
        return;
      }
      const message = data.toString('utf8');
      const json = parseJSON(message, session.request);
      if (!json) return;
      const { type, stream, ...params } = json;
      if (type === 'subscribe') {
        subscribeWebsocketToChannel(session, firstParam(stream), params);
      } else if (type === 'unsubscribe') {
        unsubscribeWebsocketFromChannel(session, firstParam(stream), params);
      } else {
        // Unknown action type: ignored.
      }
    });

    subscribeWebsocketToSystemChannel(session);

    // Parse the URL for the connection arguments (if supplied). url.parse can
    // throw on malformed input (e.g. an undecodable auth component). An
    // exception escaping this handler is not caught here and would reach the
    // process-level 'uncaughtException' handler, which exits the server.
    // Close only this connection instead; its close handler cleans up the
    // system subscriptions made above.
    let location;
    try {
      location = req.url && url.parse(req.url, true);
    } catch (err) {
      log.warn(req.requestId, `Failed to parse connection URL: ${err}`);
      ws.close(1008, 'Invalid request URL');
      return;
    }

    if (location && location.query.stream) {
      subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
    }
  });
  // Keep-alive sweep, every 30 seconds. A client whose `isAlive` flag is
  // still false has not answered the previous ping (its 'pong' handler sets
  // the flag back to true), so it is terminated. Every other client is marked
  // as pending and pinged again.
  setInterval(() => {
    for (const client of wss.clients) {
      if (client.isAlive !== false) {
        client.isAlive = false;
        client.ping('', false);
      } else {
        client.terminate();
      }
    }
  }, 30000);
  attachServerWithConfig(server, address => {
    log.warn(`Streaming API now listening on ${address}`);
  });

  // Graceful shutdown on SIGINT/SIGTERM: stop accepting connections and exit
  // successfully.
  const onExit = () => {
    server.close();
    process.exit(0);
  };

  // Last-resort handler for uncaught exceptions. Exit with a non-zero status
  // (Node's own default for an uncaught exception is 1) so that process
  // supervisors can tell a crash from a clean shutdown.
  const onError = (err) => {
    log.error(err);
    server.close();
    process.exit(1);
  };

  process.on('SIGINT', onExit);
  process.on('SIGTERM', onExit);
  // 'exit' listeners run while the process is already exiting, so only
  // synchronous cleanup belongs here. Calling process.exit(0) from inside
  // this listener would override the pending exit code (e.g. the 1 set by
  // onError) with 0.
  process.on('exit', () => {
    server.close();
  });
  process.on('uncaughtException', onError);
  1259. };
  /**
   * Start `server` listening according to the environment.
   *
   * - `SOCKET` set, or a non-numeric `PORT`: listen on that path as a unix
   *   socket and make it world read/writable (0o666). `onSuccess` receives
   *   the socket path.
   * - Otherwise: listen on TCP port `+PORT` (default 4000) and host `BIND`
   *   (default 127.0.0.1). `onSuccess` receives `address:port`.
   *
   * @param {any} server
   * @param {function(string): void} [onSuccess]
   */
  const attachServerWithConfig = (server, onSuccess) => {
    const { SOCKET, PORT, BIND } = process.env;

    if (SOCKET || (PORT && isNaN(+PORT))) {
      server.listen(SOCKET || PORT, () => {
        // The socket's permissions must not depend on whether the caller
        // passed a success callback, so chmod unconditionally.
        fs.chmodSync(server.address(), 0o666);
        if (onSuccess) {
          onSuccess(server.address());
        }
      });
    } else {
      server.listen(+PORT || 4000, BIND || '127.0.0.1', () => {
        if (onSuccess) {
          const { address, port } = server.address();
          onSuccess(`${address}:${port}`);
        }
      });
    }
  };
  // Entry point: start the streaming server as soon as this module is loaded.
  startServer();