index.js 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220
  1. // @ts-check
  2. const os = require('os');
  3. const throng = require('throng');
  4. const dotenv = require('dotenv');
  5. const express = require('express');
  6. const http = require('http');
  7. const redis = require('redis');
  8. const pg = require('pg');
  9. const log = require('npmlog');
  10. const url = require('url');
  11. const uuid = require('uuid');
  12. const fs = require('fs');
  13. const WebSocket = require('ws');
  // NODE_ENV has to come from the real process environment: it decides which
  // environment file is loaded below.
  const env = process.env.NODE_ENV || 'development';

  // Load the environment file before reading any other setting, so values that
  // are only defined in .env / .env.production are honoured. dotenv does not
  // replace variables that are already present in process.env.
  dotenv.config({
    path: env === 'production' ? '.env.production' : '.env',
  });

  // When any of these modes is enabled, clients must present an access token
  // (passed as `required` to accountFromRequest by the HTTP and WebSocket paths).
  const alwaysRequireAuth = process.env.LIMITED_FEDERATION_MODE === 'true' || process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true';

  log.level = process.env.LOG_LEVEL || 'verbose';
  /**
   * Converts a PostgreSQL connection URL (DATABASE_URL) into a partial
   * node-postgres configuration.
   *
   * Only the components present in the URL are set, so the result can be
   * merged over a set of defaults without erasing them. In particular, a URL
   * without a password leaves any default password untouched.
   *
   * @param {string} dbUrl
   * @return {Object.<string, any>}
   */
  const dbUrlToConfig = (dbUrl) => {
    if (!dbUrl) {
      return {};
    }

    const params = url.parse(dbUrl, true);
    const config = {};

    if (params.auth) {
      // url.parse() percent-decodes `auth`, so the password itself may contain
      // ':'. Only the first colon separates the user name from the password.
      const separator = params.auth.indexOf(':');

      if (separator === -1) {
        config.user = params.auth;
      } else {
        config.user = params.auth.slice(0, separator);
        config.password = params.auth.slice(separator + 1);
      }
    }

    if (params.hostname) {
      config.host = params.hostname;
    }

    if (params.port) {
      config.port = params.port;
    }

    if (params.pathname) {
      config.database = params.pathname.split('/')[1];
    }

    const ssl = params.query && params.query.ssl;

    if (ssl === 'true' || ssl === '1') {
      config.ssl = true;
    }

    return config;
  };
  /**
   * Creates a node-redis client from a set of default options plus an
   * optional REDIS_URL, registers an error logger on it, and connects it.
   *
   * `defaultConfig` is copied, never modified: the same defaults object is
   * passed in for more than one client.
   *
   * @param {Object.<string, any>} defaultConfig base options for redis.createClient
   * @param {string} redisUrl optional `redis://…` URL, or `unix://<socket path>`
   * @return {Promise.<any>} the connected client
   */
  const redisUrlToClient = async (defaultConfig, redisUrl) => {
    let config;

    if (!redisUrl) {
      config = Object.assign({}, defaultConfig);
    } else if (redisUrl.startsWith('unix://')) {
      // The socket options are replaced wholesale: host/port make no sense for a UNIX socket.
      config = Object.assign({}, defaultConfig, {
        socket: {
          path: redisUrl.slice(7),
        },
      });
    } else {
      config = Object.assign({}, defaultConfig, {
        url: redisUrl,
      });
    }

    const client = redis.createClient(config);

    client.on('error', (err) => log.error('Redis Client Error!', err));

    await client.connect();

    return client;
  };
  // Worker count: STREAMING_CLUSTER_NUM when it parses to a non-zero number;
  // otherwise one worker in development, or one fewer than the CPU count
  // (but at least one) everywhere else.
  const numWorkers = Number(process.env.STREAMING_CLUSTER_NUM)
    || (env === 'development' ? 1 : Math.max(os.cpus().length - 1, 1));
  /**
   * Parses `json`, returning `null` instead of throwing when it is malformed.
   *
   * It handles both messages coming from Redis and messages sent by clients
   * over a WebSocket connection. `req` is only used to attribute a parse
   * failure in the log; pass `null` for Redis messages.
   *
   * @param {string} json
   * @param {any?} req
   * @returns {Object.<string, any>|null}
   */
  const parseJSON = (json, req) => {
    let parsed;

    try {
      parsed = JSON.parse(json);
    } catch (err) {
      /* FIXME: This logging isn't great, and should probably be done at the
       * call-site of parseJSON, not in the method, but this would require changing
       * the signature of parseJSON to return something akin to a Result type:
       * [Error|null, null|Object<string,any>], and then handling the error
       * scenarios.
       */
      if (!req) {
        log.warn(`Error parsing message from redis: ${err}`);
      } else if (req.accountId) {
        log.warn(req.requestId, `Error parsing message from user ${req.accountId}: ${err}`);
      } else {
        log.silly(req.requestId, `Error parsing message from ${req.remoteAddress}: ${err}`);
      }

      return null;
    }

    return parsed;
  };
  /**
   * Logs startup information for the streaming server master process.
   */
  const startMaster = () => {
    const { SOCKET: socketPath, PORT: port } = process.env;

    // A non-numeric PORT is the legacy way of passing a UNIX socket path.
    const usesPortHack = !socketPath && port && Number.isNaN(Number(port));

    if (usesPortHack) {
      log.warn('UNIX domain socket is now supported by using SOCKET. Please migrate from PORT hack.');
    }

    log.warn(`Starting streaming API server master with ${numWorkers} workers`);
  };
  109. const startWorker = async (workerId) => {
  log.warn(`Starting worker ${workerId}`);

  // Connection defaults for each NODE_ENV. Every field can be overridden with
  // a DB_* variable. Unset values fall back to node-postgres' own defaults in
  // development and to fixed values in production.
  const pgConfigs = {
    development: {
      user: process.env.DB_USER || pg.defaults.user,
      password: process.env.DB_PASS || pg.defaults.password,
      database: process.env.DB_NAME || 'mastodon_development',
      host: process.env.DB_HOST || pg.defaults.host,
      port: process.env.DB_PORT || pg.defaults.port,
      max: 10,
    },
    production: {
      user: process.env.DB_USER || 'mastodon',
      password: process.env.DB_PASS || '',
      database: process.env.DB_NAME || 'mastodon_production',
      host: process.env.DB_HOST || 'localhost',
      port: process.env.DB_PORT || 5432,
      max: 10,
    },
  };

  // Any DB_SSLMODE other than "disable" turns SSL on for both environments.
  if (process.env.DB_SSLMODE && process.env.DB_SSLMODE !== 'disable') {
    Object.values(pgConfigs).forEach((pgConfig) => {
      pgConfig.ssl = true;
    });
  }
  const app = express();

  // TRUSTED_PROXY_IP is a comma- and/or whitespace-separated list. Without it,
  // loopback and unique-local addresses are trusted.
  app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');

  // pgConfigs only has "development" and "production" entries. Any other
  // NODE_ENV (e.g. "test" or "staging") falls back to the development defaults,
  // consistent with loading `.env` for every non-production environment,
  // instead of handing `undefined` to Object.assign (which throws). A fresh
  // object is built so pgConfigs itself is left unmodified.
  const pgPool = new pg.Pool(Object.assign({}, pgConfigs[env] || pgConfigs.development, dbUrlToConfig(process.env.DATABASE_URL)));

  const server = http.createServer(app);

  const redisNamespace = process.env.REDIS_NAMESPACE || null;

  const redisParams = {
    socket: {
      host: process.env.REDIS_HOST || '127.0.0.1',
      port: process.env.REDIS_PORT || 6379,
    },
    database: process.env.REDIS_DB || 0,
    password: process.env.REDIS_PASSWORD || undefined,
  };

  if (redisNamespace) {
    redisParams.namespace = redisNamespace;
  }

  // Prefix prepended to the channel and key names this worker uses in Redis.
  const redisPrefix = redisNamespace ? `${redisNamespace}:` : '';

  /**
   * Local listeners, keyed by (prefixed) Redis channel name.
   * @type {Object.<string, Array.<function(Object<string, any>): void>>}
   */
  const subs = {};

  // One client is dedicated to SUBSCRIBE; the other issues regular commands
  // such as the heartbeat SET (a Redis connection in subscribe mode cannot).
  const redisSubscribeClient = await redisUrlToClient(redisParams, process.env.REDIS_URL);
  const redisClient = await redisUrlToClient(redisParams, process.env.REDIS_URL);
  /**
   * Advertises that this worker has listeners on `channels`. For each channel
   * it sets a `subscribed:<channel>` key now and again every `interval`
   * seconds. Each key expires after three intervals, so it disappears on its
   * own if the heartbeat stops.
   *
   * @param {string[]} channels
   * @return {function(): void} stops the periodic refresh
   */
  const subscriptionHeartbeat = channels => {
    const interval = 6 * 60;

    const tellSubscribed = () => {
      channels.forEach(channel => {
        // node-redis v4 takes SET modifiers as an options object. Positional
        // ('EX', seconds) arguments are not treated as an expiry, which left
        // these keys without a TTL.
        redisClient.set(`${redisPrefix}subscribed:${channel}`, '1', { EX: interval * 3 })
          .catch(err => log.error(`Error refreshing subscription heartbeat for ${channel}`, err));
      });
    };

    tellSubscribed();

    const heartbeat = setInterval(tellSubscribed, interval * 1000);

    return () => {
      clearInterval(heartbeat);
    };
  };
  /**
   * Callback for the Redis subscriber. It decodes the message once and hands
   * the result to every local listener registered for the channel. Messages
   * for channels without listeners, and messages that do not parse, are dropped.
   *
   * @param {string} message
   * @param {string} channel
   */
  const onRedisMessage = (message, channel) => {
    const listeners = subs[channel];

    log.silly(`New message on channel ${channel}`);

    if (!listeners) {
      return;
    }

    const decoded = parseJSON(message, null);

    if (decoded) {
      for (const listener of listeners) {
        listener(decoded);
      }
    }
  };
  /**
   * @callback SubscriptionListener
   * @param {ReturnType<parseJSON>} json of the message
   * @returns {void}
   */

  /**
   * Registers `callback` for messages on a Redis channel. Redis SUBSCRIBE is
   * only issued when the first local listener for the channel is added; later
   * listeners share that subscription.
   *
   * @param {string} channel
   * @param {SubscriptionListener} callback
   */
  const subscribe = (channel, callback) => {
    log.silly(`Adding listener for ${channel}`);

    subs[channel] = subs[channel] || [];

    if (subs[channel].length === 0) {
      log.verbose(`Subscribe ${channel}`);
      // subscribe() returns a promise. Log failures instead of leaving an
      // unhandled rejection, which can terminate the worker.
      redisSubscribeClient.subscribe(channel, onRedisMessage)
        .catch(err => log.error(`Error subscribing to ${channel}`, err));
    }

    subs[channel].push(callback);
  };
  /**
   * Removes `callback` from a channel's local listeners. When the last one is
   * removed, Redis UNSUBSCRIBE is issued and the channel entry is deleted.
   * Unknown channels are ignored.
   *
   * @param {string} channel
   * @param {SubscriptionListener} callback
   */
  const unsubscribe = (channel, callback) => {
    log.silly(`Removing listener for ${channel}`);

    if (!subs[channel]) {
      return;
    }

    subs[channel] = subs[channel].filter(item => item !== callback);

    if (subs[channel].length === 0) {
      log.verbose(`Unsubscribe ${channel}`);
      // unsubscribe() returns a promise. Log failures instead of leaving an
      // unhandled rejection.
      redisSubscribeClient.unsubscribe(channel)
        .catch(err => log.error(`Error unsubscribing from ${channel}`, err));
      delete subs[channel];
    }
  };
  /**
   * Explicit "false" spellings accepted for boolean query parameters such as
   * `only_media`.
   */
  const FALSE_VALUES = [
    false,
    0,
    '0',
    'f',
    'F',
    'false',
    'FALSE',
    'off',
    'OFF',
  ];

  /**
   * Interprets a query-parameter value as a boolean. Missing/empty values and
   * the spellings in FALSE_VALUES are false; anything else is true.
   *
   * @param {any} value
   * @return {boolean}
   */
  const isTruthy = value =>
    Boolean(value) && !FALSE_VALUES.includes(value);
  /**
   * Express middleware that allows cross-origin GET/OPTIONS requests from any
   * origin carrying the Authorization, Accept and Cache-Control headers.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const allowCrossDomain = (req, res, next) => {
    const corsHeaders = [
      ['Access-Control-Allow-Origin', '*'],
      ['Access-Control-Allow-Headers', 'Authorization, Accept, Cache-Control'],
      ['Access-Control-Allow-Methods', 'GET, OPTIONS'],
    ];

    for (const [name, value] of corsHeaders) {
      res.header(name, value);
    }

    next();
  };
  /**
   * Express middleware that tags each request with a random UUID. The id is
   * stored as `req.requestId` (used as the log prefix) and echoed back in the
   * X-Request-Id response header.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const setRequestId = (req, res, next) => {
    const requestId = uuid.v4();

    req.requestId = requestId;
    res.header('X-Request-Id', requestId);

    next();
  };
  /**
   * Express middleware that copies the peer address of the underlying socket
   * to `req.remoteAddress`. Logging uses it, and it identifies anonymous
   * clients in stream logs.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const setRemoteAddress = (req, res, next) => {
    // `req.connection` is a deprecated alias of `req.socket`.
    req.remoteAddress = req.socket.remoteAddress;

    next();
  };
  /**
   * Tells whether the request's token holds at least one of `necessaryScopes`.
   * Expects `req.scopes` to have been filled in by accountFromToken.
   *
   * @param {any} req
   * @param {string[]} necessaryScopes
   * @return {boolean}
   */
  const isInScope = (req, necessaryScopes) => {
    for (const grantedScope of req.scopes) {
      if (necessaryScopes.includes(grantedScope)) {
        return true;
      }
    }

    return false;
  };
  /**
   * Looks up a non-revoked OAuth access token. On success it stores the
   * token's id, scopes, account id, chosen languages and device id on `req`.
   *
   * Rejects with the database error, or with a 401 "Invalid access token"
   * error when no matching token exists.
   *
   * @param {string} token
   * @param {any} req
   * @return {Promise.<void>}
   */
  const accountFromToken = (token, req) => new Promise((resolve, reject) => {
    pgPool.connect((err, client, done) => {
      if (err) {
        reject(err);
        return;
      }

      client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
        done();

        if (err) {
          reject(err);
          return;
        }

        if (result.rows.length === 0) {
          const invalidTokenError = new Error('Invalid access token');
          invalidTokenError.status = 401;
          reject(invalidTokenError);
          return;
        }

        const row = result.rows[0];

        req.accessTokenId = row.id;
        // `scopes` is a space-separated string. Treat a NULL column as "no
        // scopes": an exception thrown inside this callback would not be
        // turned into a rejection of the promise.
        req.scopes = row.scopes ? row.scopes.split(' ') : [];
        req.accountId = row.account_id;
        req.chosenLanguages = row.chosen_languages;
        req.deviceId = row.device_id;

        resolve();
      });
    });
  });
  /**
   * Authenticates a request from its access token. The token is taken from
   * the `Authorization: Bearer …` header, otherwise from the `access_token`
   * query parameter, otherwise from the Sec-WebSocket-Protocol header.
   *
   * Without any token, it resolves when `required` is false. Otherwise it
   * rejects with a 401 "Missing access token" error.
   *
   * @param {any} req
   * @param {boolean=} required
   * @return {Promise.<void>}
   */
  const accountFromRequest = async (req, required = true) => {
    const { authorization } = req.headers;
    const { query } = url.parse(req.url, true);
    const accessToken = query.access_token || req.headers['sec-websocket-protocol'];

    if (!authorization && !accessToken) {
      if (!required) {
        return undefined;
      }

      const err = new Error('Missing access token');
      err.status = 401;
      throw err;
    }

    const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken;

    return accountFromToken(token, req);
  };
  /**
   * Maps an HTTP streaming path to its channel name. For the public timelines,
   * a truthy `only_media` query parameter selects the `:media` variant.
   * Unknown paths yield `undefined`.
   *
   * @param {any} req
   * @returns {string|undefined}
   */
  const channelNameFromPath = req => {
    const { path, query } = req;
    const onlyMedia = isTruthy(query.only_media);

    // path -> [channel name, whether an `only_media` variant exists]
    const routes = new Map([
      ['/api/v1/streaming/user', ['user', false]],
      ['/api/v1/streaming/user/notification', ['user:notification', false]],
      ['/api/v1/streaming/public', ['public', true]],
      ['/api/v1/streaming/public/local', ['public:local', true]],
      ['/api/v1/streaming/public/remote', ['public:remote', true]],
      ['/api/v1/streaming/hashtag', ['hashtag', false]],
      ['/api/v1/streaming/hashtag/local', ['hashtag:local', false]],
      ['/api/v1/streaming/direct', ['direct', false]],
      ['/api/v1/streaming/list', ['list', false]],
    ]);

    const route = routes.get(path);

    if (!route) {
      return undefined;
    }

    const [channelName, hasMediaVariant] = route;

    return hasMediaVariant && onlyMedia ? `${channelName}:media` : channelName;
  };
  /** Channels that can be read without any OAuth scope. */
  const PUBLIC_CHANNELS = [
    'public',
    'public:media',
    'public:local',
    'public:local:media',
    'public:remote',
    'public:remote:media',
    'hashtag',
    'hashtag:local',
  ];

  /**
   * Resolves when the request's token may read `channelName`. Otherwise it
   * rejects with a 401 "Access token does not cover required scopes" error.
   *
   * Public channels need no scope. `read` grants every stream. Otherwise the
   * notification stream needs `read:notifications` and every other stream
   * needs `read:statuses`. The user stream only contains notifications when
   * the token has `read` or `read:notifications`; that is handled separately.
   *
   * @param {any} req
   * @param {string} channelName
   * @return {Promise.<void>}
   */
  const checkScopes = async (req, channelName) => {
    log.silly(req.requestId, `Checking OAuth scopes for ${channelName}`);

    if (PUBLIC_CHANNELS.includes(channelName)) {
      return;
    }

    const streamSpecificScope = channelName === 'user:notification' ? 'read:notifications' : 'read:statuses';
    const acceptedScopes = ['read', streamSpecificScope];
    const { scopes } = req;
    const authorized = Boolean(scopes) && acceptedScopes.some(scope => scopes.includes(scope));

    if (!authorized) {
      const err = new Error('Access token does not cover required scopes');
      err.status = 401;
      throw err;
    }
  };
  /**
   * `verifyClient` hook for the WebSocket server. It accepts the upgrade
   * unless authentication fails. A token is only required when
   * `alwaysRequireAuth` is set. OAuth scopes are not checked here; they are
   * checked when the client subscribes to a specific stream.
   *
   * @param {any} info
   * @param {function(boolean, number, string): void} callback
   */
  const wsVerifyClient = (info, callback) => {
    const accept = () => {
      callback(true, undefined, undefined);
    };

    const refuse = err => {
      log.error(info.req.requestId, err.toString());
      callback(false, 401, 'Unauthorized');
    };

    accountFromRequest(info.req, alwaysRequireAuth).then(accept).catch(refuse);
  };
  /**
   * @typedef SystemMessageHandlers
   * @property {function(): void} onKill
   */

  /**
   * Builds a listener for a connection's system channel. A `kill` event
   * invokes `eventHandlers.onKill`. Other events are only logged.
   *
   * @param {any} req
   * @param {SystemMessageHandlers} eventHandlers
   * @returns {function(object): void}
   */
  const createSystemMessageListener = (req, eventHandlers) => message => {
    const { event } = message;

    log.silly(req.requestId, `System message for ${req.accountId}: ${event}`);

    if (event !== 'kill') {
      return;
    }

    log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`);
    eventHandlers.onKill();
  };
  /**
   * For requests authenticated with an access token, listens on that token's
   * system channel so the HTTP response is ended on a `kill` event. The
   * listener is removed when the response closes.
   *
   * @param {any} req
   * @param {any} res
   */
  const subscribeHttpToSystemChannel = (req, res) => {
    // Anonymous requests carry no access token. Without this guard every one
    // of them would subscribe to `timeline:access_token:undefined`.
    if (!req.accessTokenId) {
      return;
    }

    const systemChannelId = `${redisPrefix}timeline:access_token:${req.accessTokenId}`;

    const listener = createSystemMessageListener(req, {
      onKill() {
        res.end();
      },
    });

    res.on('close', () => {
      unsubscribe(systemChannelId, listener);
    });

    subscribe(systemChannelId, listener);
  };
  /**
   * Express middleware for HTTP streaming requests. It authenticates the
   * request (a token is mandatory only when `alwaysRequireAuth` is set),
   * checks OAuth scopes for the requested channel, and attaches the
   * access-token system channel. Any failure is passed to `next(err)`.
   * OPTIONS requests skip all of this.
   *
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const authenticationMiddleware = (req, res, next) => {
    if (req.method === 'OPTIONS') {
      next();
      return;
    }

    const verifyScopes = () => checkScopes(req, channelNameFromPath(req));
    const watchSystemChannel = () => {
      subscribeHttpToSystemChannel(req, res);
    };
    const proceed = () => {
      next();
    };
    const fail = err => {
      next(err);
    };

    accountFromRequest(req, alwaysRequireAuth)
      .then(verifyScopes)
      .then(watchSystemChannel)
      .then(proceed)
      .catch(fail);
  };
  /**
   * Express error handler. It logs the error and, if the response has not
   * started, answers with a JSON body. Only errors carrying an explicit
   * `status` expose their message; others get a generic 500 message. If
   * headers were already sent, the error is passed on with `next(err)`.
   *
   * @param {Error} err
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const errorMiddleware = (err, req, res, next) => {
    log.error(req.requestId, err.toString());

    if (!res.headersSent) {
      res.writeHead(err.status || 500, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: err.status ? err.toString() : 'An unexpected error occurred' }));
      return;
    }

    next(err);
  };
  /**
   * Builds a comma-separated list of numbered SQL placeholders, one per
   * element of `arr`, starting at `$(shift + 1)`.
   * Example: placeholders(['a', 'b'], 2) === '$3, $4'.
   *
   * @param {array} arr
   * @param {number=} shift
   * @return {string}
   */
  const placeholders = (arr, shift = 0) => arr
    .map((_, index) => index + 1 + shift)
    .map(position => `$${position}`)
    .join(', ');
  /**
   * Resolves when list `listId` exists and belongs to the requesting account
   * (`req.accountId`).
   *
   * Rejects with the database error when the connection or query fails. When
   * the list is missing or owned by someone else, it rejects with an Error
   * rather than `undefined`, so callers can always log it.
   *
   * @param {string} listId
   * @param {any} req
   * @return {Promise.<void>}
   */
  const authorizeListAccess = (listId, req) => new Promise((resolve, reject) => {
    const { accountId } = req;

    pgPool.connect((err, client, done) => {
      if (err) {
        reject(err);
        return;
      }

      client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [listId], (err, result) => {
        done();

        if (err) {
          reject(err);
          return;
        }

        if (result.rows.length === 0 || result.rows[0].account_id !== accountId) {
          reject(new Error('List not found or not owned by the requesting account'));
          return;
        }

        resolve();
      });
    });
  });
  /**
   * Subscribes a listener to the given Redis channels and forwards every
   * message to `output` as (event, encoded payload).
   *
   * Filtering rules:
   * - local-only payloads are never sent to anonymous viewers (any event);
   * - if `needsFiltering` is set, `update` events are additionally dropped
   *   when their language is outside `req.chosenLanguages`. For logged-in
   *   viewers they are also dropped when a block or mute relationship exists
   *   with the author/mentioned accounts, or the viewer blocks the author's
   *   domain.
   *
   * @param {string[]} ids channel ids without the Redis prefix
   * @param {any} req
   * @param {function(string, string): void} output
   * @param {undefined | function(string[], SubscriptionListener): void} attachCloseHandler
   * @param {boolean=} needsFiltering
   * @returns {SubscriptionListener}
   */
  const streamFrom = (ids, req, output, attachCloseHandler, needsFiltering = false) => {
    const accountId = req.accountId || req.remoteAddress;

    log.verbose(req.requestId, `Starting stream from ${ids.join(', ')} for ${accountId}`);

    // Currently message is of type string, soon it'll be Record<string, any>
    const listener = message => {
      const { event, payload, queued_at } = message;

      const transmit = () => {
        const delta = Date.now() - queued_at;
        const encodedPayload = typeof payload === 'object' ? JSON.stringify(payload) : payload;

        log.silly(req.requestId, `Transmitting for ${accountId}: ${event} ${encodedPayload} Delay: ${delta}ms`);
        output(event, encodedPayload);
      };

      // Only send local-only statuses to logged-in users
      if (payload.local_only && !req.accountId) {
        log.silly(req.requestId, `Message ${payload.id} filtered because it was local-only`);
        return;
      }

      // Only statuses may need filtering: notifications are already
      // personalized and deletes do not matter.
      if (!needsFiltering || event !== 'update') {
        transmit();
        return;
      }

      const { account, mentions, language } = payload;
      const targetAccountIds = [account.id].concat(mentions.map(item => item.id));
      const accountDomain = account.acct.split('@')[1];

      if (Array.isArray(req.chosenLanguages) && language !== null && !req.chosenLanguages.includes(language)) {
        log.silly(req.requestId, `Message ${payload.id} filtered by language (${language})`);
        return;
      }

      // When the account is not logged in, it is not necessary to confirm the block or mute
      if (!req.accountId) {
        transmit();
        return;
      }

      pgPool.connect((err, client, done) => {
        if (err) {
          log.error(err);
          return;
        }

        const queries = [
          client.query(`SELECT 1
FROM blocks
WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
OR (account_id = $2 AND target_account_id = $1)
UNION
SELECT 1
FROM mutes
WHERE account_id = $1
AND target_account_id IN (${placeholders(targetAccountIds, 2)})`, [req.accountId, account.id].concat(targetAccountIds)),
        ];

        if (accountDomain) {
          queries.push(client.query('SELECT 1 FROM account_domain_blocks WHERE account_id = $1 AND domain = $2', [req.accountId, accountDomain]));
        }

        Promise.all(queries).then(values => {
          done();

          const [relationships, domainBlocks] = values;
          const isHidden = relationships.rows.length > 0 || Boolean(domainBlocks && domainBlocks.rows.length > 0);

          if (!isHidden) {
            transmit();
          }
        }).catch(err => {
          done();
          log.error(err);
        });
      });
    };

    for (const id of ids) {
      subscribe(`${redisPrefix}${id}`, listener);
    }

    if (typeof attachCloseHandler === 'function') {
      attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener);
    }

    return listener;
  };
  /**
   * Prepares `res` as a Server-Sent Events stream. It sends an initial
   * comment line, then a `:thump` comment every 15 seconds until the request
   * closes.
   *
   * @param {any} req
   * @param {any} res
   * @return {function(string, string): void} writes one `event:`/`data:` frame
   */
  const streamToHttp = (req, res) => {
    const accountId = req.accountId || req.remoteAddress;

    const streamHeaders = [
      ['Content-Type', 'text/event-stream'],
      ['Cache-Control', 'no-store'],
      ['Transfer-Encoding', 'chunked'],
    ];

    streamHeaders.forEach(([name, value]) => res.setHeader(name, value));

    res.write(':)\n');

    const heartbeat = setInterval(() => res.write(':thump\n'), 15000);

    const stopHeartbeat = () => {
      log.verbose(req.requestId, `Ending stream for ${accountId}`);
      clearInterval(heartbeat);
    };

    req.on('close', stopHeartbeat);

    return (event, payload) => {
      res.write(`event: ${event}\n`);
      res.write(`data: ${payload}\n\n`);
    };
  };
  /**
   * Returns an `attachCloseHandler` for streamFrom. When the HTTP request
   * closes, the listener is removed from every channel and the optional
   * `closeHandler` is run.
   *
   * @param {any} req
   * @param {function(): void} [closeHandler]
   * @returns {function(string[], SubscriptionListener): void}
   */
  const streamHttpEnd = (req, closeHandler = undefined) => (ids, listener) => {
    const onClose = () => {
      for (const id of ids) {
        unsubscribe(id, listener);
      }

      if (closeHandler) {
        closeHandler();
      }
    };

    req.on('close', onClose);
  };
  /**
   * Returns a sender that writes `{ stream, event, payload }` JSON frames to
   * `ws`. It logs, and drops the frame, if the socket is no longer open.
   *
   * @param {any} req
   * @param {any} ws
   * @param {string[]} streamName
   * @return {function(string, string): void}
   */
  const streamToWs = (req, ws, streamName) => (event, payload) => {
    if (ws.readyState === ws.OPEN) {
      ws.send(JSON.stringify({ stream: streamName, event, payload }));
      return;
    }

    log.error(req.requestId, 'Tried writing to closed socket');
  };
  /**
   * Ends `res` with a 404 JSON error body.
   *
   * @param {any} res
   */
  const httpNotFound = res => {
    const body = JSON.stringify({ error: 'Not found' });

    res.writeHead(404, { 'Content-Type': 'application/json' });
    res.end(body);
  };
  // Registration order matters:
  // - request id, remote address and CORS headers are set for every request;
  // - the health check is answered before authentication;
  // - errors passed on by authenticationMiddleware reach errorMiddleware;
  // - streaming routes are only reached after successful authentication.
  [setRequestId, setRemoteAddress, allowCrossDomain].forEach(middleware => app.use(middleware));

  app.get('/api/v1/streaming/health', (req, res) => {
    res.writeHead(200, { 'Content-Type': 'text/plain' });
    res.end('OK');
  });

  app.use(authenticationMiddleware);
  app.use(errorMiddleware);

  app.get('/api/v1/streaming/*', (req, res) => {
    const channelName = channelNameFromPath(req);

    channelNameToIds(req, channelName, req.query)
      .then(({ channelIds, options }) => {
        const onSend = streamToHttp(req, res);
        const onEnd = streamHttpEnd(req, subscriptionHeartbeat(channelIds));

        streamFrom(channelIds, req, onSend, onEnd, options.needsFiltering);
      })
      .catch(err => {
        log.verbose(req.requestId, 'Subscription error:', err.toString());
        httpNotFound(res);
      });
  });

  const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
  674. /**
  675. * @typedef StreamParams
  676. * @property {string} [tag]
  677. * @property {string} [list]
  678. * @property {string} [only_media]
  679. */
  /**
   * Lists the Redis channel ids that make up an account's `user` stream.
   * It always includes the home timeline. It adds the device channel when the
   * token has the `crypto` scope and a device id, and the notifications
   * channel when the token has `read` or `read:notifications`.
   *
   * @param {any} req
   * @return {string[]}
   */
  const channelsForUserStream = req => {
    const timeline = `timeline:${req.accountId}`;
    const channels = [timeline];

    const wantsDeviceChannel = isInScope(req, ['crypto']) && req.deviceId;

    if (wantsDeviceChannel) {
      channels.push(`${timeline}:${req.deviceId}`);
    }

    const wantsNotifications = isInScope(req, ['read', 'read:notifications']);

    if (wantsNotifications) {
      channels.push(`${timeline}:notifications`);
    }

    return channels;
  };
  /**
   * Resolves a stream name (plus its parameters) to the Redis channel ids to
   * listen on, and whether messages need per-viewer filtering.
   *
   * Rejects with a plain string describing the problem, which is sent to
   * WebSocket clients via `err.toString()`:
   * - 'No tag for stream provided' for a hashtag stream without `tag`;
   * - 'Not authorized to stream this list' when list access is denied;
   * - 'Unknown stream type' for any other name.
   *
   * @param {any} req
   * @param {string} name
   * @param {StreamParams} params
   * @return {Promise.<{ channelIds: string[], options: { needsFiltering: boolean } }>}
   */
  const channelNameToIds = async (req, name, params) => {
    const resolution = (channelIds, needsFiltering) => ({
      channelIds,
      options: { needsFiltering },
    });

    switch (name) {
    case 'user':
      return resolution(channelsForUserStream(req), false);
    case 'user:notification':
      return resolution([`timeline:${req.accountId}:notifications`], false);
    case 'public':
    case 'public:local':
    case 'public:remote':
    case 'public:media':
    case 'public:local:media':
    case 'public:remote:media':
      // Public channel ids are the stream name under the `timeline:` namespace.
      return resolution([`timeline:${name}`], true);
    case 'direct':
      return resolution([`timeline:direct:${req.accountId}`], false);
    case 'hashtag':
    case 'hashtag:local': {
      if (!params.tag || params.tag.length === 0) {
        return Promise.reject('No tag for stream provided');
      }

      const scopeSuffix = name === 'hashtag:local' ? ':local' : '';

      return resolution([`timeline:hashtag:${params.tag.toLowerCase()}${scopeSuffix}`], true);
    }
    case 'list':
      return authorizeListAccess(params.list, req).then(
        () => resolution([`timeline:list:${params.list}`], false),
        () => Promise.reject('Not authorized to stream this list'),
      );
    default:
      return Promise.reject('Unknown stream type');
    }
  };
  /**
   * Builds the `stream` value attached to WebSocket frames: the channel name,
   * followed by the list id or hashtag for those parameterised streams.
   *
   * @param {string} channelName
   * @param {StreamParams} params
   * @return {string[]}
   */
  const streamNameFromChannelName = (channelName, params) => {
    switch (channelName) {
    case 'list':
      return [channelName, params.list];
    case 'hashtag':
    case 'hashtag:local':
      return [channelName, params.tag];
    default:
      return [channelName];
    }
  };
  /**
   * @typedef WebSocketSession
   * @property {any} socket
   * @property {any} request
   * @property {Object.<string, { listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
   *   Active subscriptions, keyed by the stream's channel ids joined with ';'.
   */

  /**
   * Subscribes a WebSocket session to a stream.
   *
   * The request's scopes are checked first, then the stream name is resolved to
   * its channel ids. Subscribing again to the same set of channels is a no-op.
   * Any rejection from the scope check or channel resolution is logged and sent
   * to the client as `{ "error": "..." }`.
   *
   * @param {WebSocketSession} session
   * @param {string} channelName
   * @param {StreamParams} params
   * @return {Promise<void>}
   */
  const subscribeWebsocketToChannel = ({ socket, request, subscriptions }, channelName, params) =>
    checkScopes(request, channelName)
      .then(() => channelNameToIds(request, channelName, params))
      .then(({ channelIds, options }) => {
        // Both steps above are asynchronous (e.g. list streams wait on
        // authorizeListAccess), so the client may have disconnected meanwhile.
        // The connection's cleanup has then already run and would never release
        // a listener or heartbeat registered now, so don't register one.
        if (socket.readyState !== socket.OPEN) {
          return;
        }

        const key = channelIds.join(';');
        if (subscriptions[key]) {
          return;
        }

        const onSend = streamToWs(request, socket, streamNameFromChannelName(channelName, params));
        const stopHeartbeat = subscriptionHeartbeat(channelIds);
        const listener = streamFrom(channelIds, request, onSend, undefined, options.needsFiltering);

        subscriptions[key] = {
          listener,
          stopHeartbeat,
        };
      })
      .catch(err => {
        log.verbose(request.requestId, 'Subscription error:', err.toString());
        socket.send(JSON.stringify({ error: err.toString() }));
      });
  /**
   * Ends a WebSocket session's subscription to a stream, if it has one.
   *
   * The stream is resolved to its channel ids; the matching subscription's
   * listener is detached from each Redis channel, its heartbeat is stopped and
   * the entry is removed. Resolution failures are logged and reported to the
   * client as `{ "error": "..." }`.
   *
   * @param {WebSocketSession} session
   * @param {string} channelName
   * @param {StreamParams} params
   * @return {Promise<void>}
   */
  const unsubscribeWebsocketFromChannel = ({ socket, request, subscriptions }, channelName, params) => {
    const release = ({ channelIds }) => {
      log.verbose(request.requestId, `Ending stream from ${channelIds.join(', ')} for ${request.accountId}`);

      const key = channelIds.join(';');
      const subscription = subscriptions[key];

      // Nothing to tear down if this stream was never subscribed to.
      if (!subscription) {
        return;
      }

      const { listener, stopHeartbeat } = subscription;
      for (const channelId of channelIds) {
        unsubscribe(`${redisPrefix}${channelId}`, listener);
      }

      stopHeartbeat();
      delete subscriptions[key];
    };

    const report = err => {
      log.verbose(request.requestId, 'Unsubscription error:', err);
      socket.send(JSON.stringify({ error: err.toString() }));
    };

    return channelNameToIds(request, channelName, params).then(release).catch(report);
  };
  /**
   * Subscribes the session to the system channel of its access token
   * (`timeline:access_token:<accessTokenId>`). When the system listener's kill
   * callback fires, the socket is closed.
   *
   * The entry is stored under the channel id with a no-op heartbeat stopper, so
   * connection cleanup can release it like any other subscription.
   *
   * @param {WebSocketSession} session
   */
  const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => {
    const channelId = `timeline:access_token:${request.accessTokenId}`;

    const closeSocket = () => {
      socket.close();
    };
    const listener = createSystemMessageListener(request, { onKill: closeSocket });

    subscribe(`${redisPrefix}${channelId}`, listener);

    // No heartbeat is started for the system channel, hence the no-op stopper.
    const noop = () => {};
    subscriptions[channelId] = {
      listener,
      stopHeartbeat: noop,
    };
  };
  /**
   * Normalises a query-string value that may have been repeated: arrays yield
   * their first element, anything else is returned unchanged.
   *
   * @param {string|string[]} arrayOrString
   * @return {string}
   */
  const firstParam = arrayOrString =>
    (Array.isArray(arrayOrString) ? arrayOrString[0] : arrayOrString);
  wss.on('connection', (ws, req) => {
    const location = url.parse(req.url, true);

    req.requestId = uuid.v4();
    req.remoteAddress = ws._socket.remoteAddress;

    // Liveness flag read by the periodic ping sweep; every pong marks the peer alive again.
    ws.isAlive = true;
    ws.on('pong', () => {
      ws.isAlive = true;
    });

    /**
     * @type {WebSocketSession}
     */
    const session = {
      socket: ws,
      request: req,
      subscriptions: {},
    };

    /**
     * Releases every Redis listener and heartbeat held by this session.
     *
     * It is registered for both 'error' and 'close', so it can run more than
     * once for the same connection. Each entry is removed from
     * `session.subscriptions` as it is released, which makes later runs no-ops
     * instead of unsubscribing and stopping heartbeats a second time.
     */
    const onEnd = () => {
      Object.keys(session.subscriptions).forEach(key => {
        const { listener, stopHeartbeat } = session.subscriptions[key];
        delete session.subscriptions[key];

        // Keys are the subscription's channel ids joined with ';'.
        key.split(';').forEach(channelId => {
          unsubscribe(`${redisPrefix}${channelId}`, listener);
        });
        stopHeartbeat();
      });
    };

    ws.on('close', onEnd);
    ws.on('error', onEnd);

    ws.on('message', (data, isBinary) => {
      if (isBinary) {
        log.warn('socket', 'Received binary data, closing connection');
        ws.close(1003, 'The mastodon streaming server does not support binary messages');
        return;
      }

      const message = data.toString('utf8');
      const json = parseJSON(message, session.request);
      if (!json) return;

      const { type, stream, ...params } = json;

      if (type === 'subscribe') {
        subscribeWebsocketToChannel(session, firstParam(stream), params);
      } else if (type === 'unsubscribe') {
        unsubscribeWebsocketFromChannel(session, firstParam(stream), params);
      }
      // Any other message type is ignored.
    });

    // Every connection is attached to its access token's system channel, plus the
    // stream named by the `stream` query parameter of the connection URL, if any.
    subscribeWebsocketToSystemChannel(session);

    if (location.query.stream) {
      subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
    }
  });
  // Heartbeat sweep every 30 seconds. A client whose `isAlive` flag is still
  // false (it has not answered the previous ping) is terminated. Every other
  // client is flagged not-alive and pinged again; its 'pong' handler flips the
  // flag back to true.
  setInterval(() => {
    wss.clients.forEach(client => {
      if (client.isAlive !== false) {
        client.isAlive = false;
        client.ping('', false);
      } else {
        client.terminate();
      }
    });
  }, 30000);
  attachServerWithConfig(server, address => {
    log.warn(`Worker ${workerId} now listening on ${address}`);
  });

  /**
   * Graceful shutdown on SIGINT/SIGTERM: stop accepting connections and exit
   * with status 0.
   */
  const onExit = () => {
    log.warn(`Worker ${workerId} exiting`);
    server.close();
    process.exit(0);
  };

  /**
   * Last-resort handler for uncaught exceptions: log the error, stop accepting
   * connections and exit with a non-zero status, so the crash is visible to
   * whatever supervises the process.
   *
   * @param {Error} err
   */
  const onError = (err) => {
    log.error(err);
    server.close();
    process.exit(1);
  };

  process.on('SIGINT', onExit);
  process.on('SIGTERM', onExit);

  // 'exit' is emitted while the process is already terminating, including from
  // the process.exit() calls above. Calling process.exit(0) again from this
  // listener would replace the pending exit code with 0 (masking failures), so
  // only release the server here.
  process.on('exit', () => {
    server.close();
  });

  process.on('uncaughtException', onError);
  962. };
  /**
   * Starts `server` listening on the address configured through the environment.
   *
   * - If SOCKET is set, or PORT is set to a non-numeric value, that value
   *   (SOCKET preferred) is used as a socket path. Once listening, and only when
   *   `onSuccess` is given, the socket file is chmod-ed to 0o666 and its path
   *   is passed to `onSuccess`.
   * - Otherwise it listens on TCP port PORT (falling back to 4000) and host BIND
   *   (falling back to '127.0.0.1'), and passes `host:port` to `onSuccess`.
   *
   * @param {any} server
   * @param {function(string): void} [onSuccess]
   */
  const attachServerWithConfig = (server, onSuccess) => {
    const { SOCKET, PORT, BIND } = process.env;
    const usesSocketPath = SOCKET || (PORT && isNaN(+PORT));

    if (usesSocketPath) {
      server.listen(SOCKET || PORT, () => {
        if (!onSuccess) {
          return;
        }
        const socketPath = server.address();
        fs.chmodSync(socketPath, 0o666);
        onSuccess(socketPath);
      });
      return;
    }

    server.listen(+PORT || 4000, BIND || '127.0.0.1', () => {
      if (!onSuccess) {
        return;
      }
      const { address, port } = server.address();
      onSuccess(`${address}:${port}`);
    });
  };
  /**
   * Checks that the configured port or socket can be bound by briefly listening
   * on it with a throwaway HTTP server.
   *
   * Calls `onSuccess()` with no argument after the probe has listened and closed
   * again, or `onSuccess(err)` with the error the probe emitted.
   *
   * @param {function(Error=): void} onSuccess
   */
  const onPortAvailable = onSuccess => {
    const probe = http.createServer();

    const handleListening = () => {
      probe.once('close', () => onSuccess());
      probe.close();
    };
    const handleError = err => {
      onSuccess(err);
    };

    probe.once('error', handleError);
    probe.once('listening', handleListening);

    attachServerWithConfig(probe);
  };
  // Probe the configured port/socket before starting throng, and refuse to start
  // if it cannot be bound.
  onPortAvailable(err => {
    if (err) {
      log.error('Could not start server, the port or socket is in use');
      // Without this the process would end with status 0, hiding the failed
      // start from whatever supervises it.
      process.exitCode = 1;
      return;
    }

    throng({
      workers: numWorkers,
      lifetime: Infinity,
      start: startWorker,
      master: startMaster,
    });
  });