updater.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import contextlib
  8. import logging
  9. import os
  10. import subprocess
  11. import sys
  12. import time
  13. import typing
  14. import git
  15. from git import GitCommandError, Repo
  16. from hikkatl.extensions.html import CUSTOM_EMOJIS
  17. from hikkatl.tl.functions.messages import (
  18. GetDialogFiltersRequest,
  19. UpdateDialogFilterRequest,
  20. )
  21. from hikkatl.tl.types import DialogFilter, Message
  22. from .. import loader, main, utils, version
  23. from .._internal import restart
  24. from ..inline.types import InlineCall
  25. logger = logging.getLogger(__name__)
  26. @loader.tds
  27. class UpdaterMod(loader.Module):
  28. """Updates itself"""
  29. strings = {"name": "Updater"}
  30. def __init__(self):
  31. self.config = loader.ModuleConfig(
  32. loader.ConfigValue(
  33. "GIT_ORIGIN_URL",
  34. "https://github.com/hikariatama/Hikka",
  35. lambda: self.strings("origin_cfg_doc"),
  36. validator=loader.validators.Link(),
  37. )
  38. )
  39. @loader.command()
  40. async def restart(self, message: Message):
  41. args = utils.get_args_raw(message)
  42. secure_boot = any(trigger in args for trigger in {"--secure-boot", "-sb"})
  43. try:
  44. if (
  45. "-f" in args
  46. or not self.inline.init_complete
  47. or not await self.inline.form(
  48. message=message,
  49. text=self.strings(
  50. "secure_boot_confirm" if secure_boot else "restart_confirm"
  51. ),
  52. reply_markup=[
  53. {
  54. "text": self.strings("btn_restart"),
  55. "callback": self.inline_restart,
  56. "args": (secure_boot,),
  57. },
  58. {"text": self.strings("cancel"), "action": "close"},
  59. ],
  60. )
  61. ):
  62. raise
  63. except Exception:
  64. await self.restart_common(message, secure_boot)
  65. async def inline_restart(self, call: InlineCall, secure_boot: bool = False):
  66. await self.restart_common(call, secure_boot=secure_boot)
  67. async def process_restart_message(self, msg_obj: typing.Union[InlineCall, Message]):
  68. self.set(
  69. "selfupdatemsg",
  70. (
  71. msg_obj.inline_message_id
  72. if hasattr(msg_obj, "inline_message_id")
  73. else f"{utils.get_chat_id(msg_obj)}:{msg_obj.id}"
  74. ),
  75. )
  76. async def restart_common(
  77. self,
  78. msg_obj: typing.Union[InlineCall, Message],
  79. secure_boot: bool = False,
  80. ):
  81. if (
  82. hasattr(msg_obj, "form")
  83. and isinstance(msg_obj.form, dict)
  84. and "uid" in msg_obj.form
  85. and msg_obj.form["uid"] in self.inline._units
  86. and "message" in self.inline._units[msg_obj.form["uid"]]
  87. ):
  88. message = self.inline._units[msg_obj.form["uid"]]["message"]
  89. else:
  90. message = msg_obj
  91. if secure_boot:
  92. self._db.set(loader.__name__, "secure_boot", True)
  93. msg_obj = await utils.answer(
  94. msg_obj,
  95. self.strings("restarting_caption").format(
  96. utils.get_platform_emoji()
  97. if self._client.hikka_me.premium
  98. and CUSTOM_EMOJIS
  99. and isinstance(msg_obj, Message)
  100. else "Hikka"
  101. ),
  102. )
  103. await self.process_restart_message(msg_obj)
  104. self.set("restart_ts", time.time())
  105. await self._db.remote_force_save()
  106. if "LAVHOST" in os.environ:
  107. os.system("lavhost restart")
  108. return
  109. with contextlib.suppress(Exception):
  110. await main.hikka.web.stop()
  111. handler = logging.getLogger().handlers[0]
  112. handler.setLevel(logging.CRITICAL)
  113. for client in self.allclients:
  114. # Terminate main loop of all running clients
  115. # Won't work if not all clients are ready
  116. if client is not message.client:
  117. await client.disconnect()
  118. await message.client.disconnect()
  119. restart()
  120. async def download_common(self):
  121. try:
  122. repo = Repo(os.path.dirname(utils.get_base_dir()))
  123. origin = repo.remote("origin")
  124. r = origin.pull()
  125. new_commit = repo.head.commit
  126. for info in r:
  127. if info.old_commit:
  128. for d in new_commit.diff(info.old_commit):
  129. if d.b_path == "requirements.txt":
  130. return True
  131. return False
  132. except git.exc.InvalidGitRepositoryError:
  133. repo = Repo.init(os.path.dirname(utils.get_base_dir()))
  134. origin = repo.create_remote("origin", self.config["GIT_ORIGIN_URL"])
  135. origin.fetch()
  136. repo.create_head("master", origin.refs.master)
  137. repo.heads.master.set_tracking_branch(origin.refs.master)
  138. repo.heads.master.checkout(True)
  139. return False
  140. @staticmethod
  141. def req_common():
  142. # Now we have downloaded new code, install requirements
  143. logger.debug("Installing new requirements...")
  144. try:
  145. subprocess.run(
  146. [
  147. sys.executable,
  148. "-m",
  149. "pip",
  150. "install",
  151. "-r",
  152. os.path.join(
  153. os.path.dirname(utils.get_base_dir()),
  154. "requirements.txt",
  155. ),
  156. "--user",
  157. ],
  158. check=True,
  159. )
  160. except subprocess.CalledProcessError:
  161. logger.exception("Req install failed")
  162. @loader.command()
  163. async def update(self, message: Message):
  164. try:
  165. args = utils.get_args_raw(message)
  166. current = utils.get_git_hash()
  167. upcoming = next(
  168. git.Repo().iter_commits(f"origin/{version.branch}", max_count=1)
  169. ).hexsha
  170. if (
  171. "-f" in args
  172. or not self.inline.init_complete
  173. or not await self.inline.form(
  174. message=message,
  175. text=(
  176. self.strings("update_confirm").format(
  177. current, current[:8], upcoming, upcoming[:8]
  178. )
  179. if upcoming != current
  180. else self.strings("no_update")
  181. ),
  182. reply_markup=[
  183. {
  184. "text": self.strings("btn_update"),
  185. "callback": self.inline_update,
  186. },
  187. {"text": self.strings("cancel"), "action": "close"},
  188. ],
  189. )
  190. ):
  191. raise
  192. except Exception:
  193. await self.inline_update(message)
  194. async def inline_update(
  195. self,
  196. msg_obj: typing.Union[InlineCall, Message],
  197. hard: bool = False,
  198. ):
  199. # We don't really care about asyncio at this point, as we are shutting down
  200. if hard:
  201. os.system(f"cd {utils.get_base_dir()} && cd .. && git reset --hard HEAD")
  202. try:
  203. if "LAVHOST" in os.environ:
  204. msg_obj = await utils.answer(
  205. msg_obj,
  206. self.strings("lavhost_update").format(
  207. "</b><emoji document_id=5192756799647785066>✌️</emoji><emoji"
  208. " document_id=5193117564015747203>✌️</emoji><emoji"
  209. " document_id=5195050806105087456>✌️</emoji><emoji"
  210. " document_id=5195457642587233944>✌️</emoji><b>"
  211. if self._client.hikka_me.premium
  212. and CUSTOM_EMOJIS
  213. and isinstance(msg_obj, Message)
  214. else "lavHost"
  215. ),
  216. )
  217. await self.process_restart_message(msg_obj)
  218. os.system("lavhost update")
  219. return
  220. with contextlib.suppress(Exception):
  221. msg_obj = await utils.answer(msg_obj, self.strings("downloading"))
  222. req_update = await self.download_common()
  223. with contextlib.suppress(Exception):
  224. msg_obj = await utils.answer(msg_obj, self.strings("installing"))
  225. if req_update:
  226. self.req_common()
  227. await self.restart_common(msg_obj)
  228. except GitCommandError:
  229. if not hard:
  230. await self.inline_update(msg_obj, True)
  231. return
  232. logger.critical("Got update loop. Update manually via .terminal")
  233. @loader.command()
  234. async def source(self, message: Message):
  235. await utils.answer(
  236. message,
  237. self.strings("source").format(self.config["GIT_ORIGIN_URL"]),
  238. )
  239. async def client_ready(self):
  240. if self.get("selfupdatemsg") is not None:
  241. try:
  242. await self.update_complete()
  243. except Exception:
  244. logger.exception("Failed to complete update!")
  245. if self.get("do_not_create", False):
  246. return
  247. try:
  248. await self._add_folder()
  249. except Exception:
  250. logger.exception("Failed to add folder!")
  251. self.set("do_not_create", True)
  252. async def _add_folder(self):
  253. folders = await self._client(GetDialogFiltersRequest())
  254. if any(getattr(folder, "title", None) == "hikka" for folder in folders):
  255. return
  256. try:
  257. folder_id = (
  258. max(
  259. folders,
  260. key=lambda x: x.id,
  261. ).id
  262. + 1
  263. )
  264. except ValueError:
  265. folder_id = 2
  266. try:
  267. await self._client(
  268. UpdateDialogFilterRequest(
  269. folder_id,
  270. DialogFilter(
  271. folder_id,
  272. title="hikka",
  273. pinned_peers=(
  274. [
  275. await self._client.get_input_entity(
  276. self._client.loader.inline.bot_id
  277. )
  278. ]
  279. if self._client.loader.inline.init_complete
  280. else []
  281. ),
  282. include_peers=[
  283. await self._client.get_input_entity(dialog.entity)
  284. async for dialog in self._client.iter_dialogs(
  285. None,
  286. ignore_migrated=True,
  287. )
  288. if dialog.name
  289. in {
  290. "hikka-logs",
  291. "hikka-onload",
  292. "hikka-assets",
  293. "hikka-backups",
  294. "hikka-acc-switcher",
  295. "silent-tags",
  296. }
  297. and dialog.is_channel
  298. and (
  299. dialog.entity.participants_count == 1
  300. or dialog.entity.participants_count == 2
  301. and dialog.name in {"hikka-logs", "silent-tags"}
  302. )
  303. or (
  304. self._client.loader.inline.init_complete
  305. and dialog.entity.id
  306. == self._client.loader.inline.bot_id
  307. )
  308. or dialog.entity.id
  309. in [
  310. 1554874075,
  311. 1697279580,
  312. 1679998924,
  313. ] # official hikka chats
  314. ],
  315. emoticon="🐱",
  316. exclude_peers=[],
  317. contacts=False,
  318. non_contacts=False,
  319. groups=False,
  320. broadcasts=False,
  321. bots=False,
  322. exclude_muted=False,
  323. exclude_read=False,
  324. exclude_archived=False,
  325. ),
  326. )
  327. )
  328. except Exception:
  329. logger.critical(
  330. "Can't create Hikka folder. Possible reasons are:\n"
  331. "- User reached the limit of folders in Telegram\n"
  332. "- User got floodwait\n"
  333. "Ignoring error and adding folder addition to ignore list"
  334. )
  335. async def update_complete(self):
  336. logger.debug("Self update successful! Edit message")
  337. start = self.get("restart_ts")
  338. try:
  339. took = round(time.time() - start)
  340. except Exception:
  341. took = "n/a"
  342. msg = self.strings("success").format(utils.ascii_face(), took)
  343. ms = self.get("selfupdatemsg")
  344. if ":" in str(ms):
  345. chat_id, message_id = ms.split(":")
  346. chat_id, message_id = int(chat_id), int(message_id)
  347. await self._client.edit_message(chat_id, message_id, msg)
  348. return
  349. await self.inline.bot.edit_message_text(
  350. inline_message_id=ms,
  351. text=self.inline.sanitise_text(msg),
  352. )
  353. async def full_restart_complete(self, secure_boot: bool = False):
  354. start = self.get("restart_ts")
  355. try:
  356. took = round(time.time() - start)
  357. except Exception:
  358. took = "n/a"
  359. self.set("restart_ts", None)
  360. ms = self.get("selfupdatemsg")
  361. msg = self.strings(
  362. "secure_boot_complete" if secure_boot else "full_success"
  363. ).format(utils.ascii_face(), took)
  364. if ms is None:
  365. return
  366. self.set("selfupdatemsg", None)
  367. if ":" in str(ms):
  368. chat_id, message_id = ms.split(":")
  369. chat_id, message_id = int(chat_id), int(message_id)
  370. await self._client.edit_message(chat_id, message_id, msg)
  371. await asyncio.sleep(60)
  372. await self._client.delete_messages(chat_id, message_id)
  373. return
  374. await self.inline.bot.edit_message_text(
  375. inline_message_id=ms,
  376. text=self.inline.sanitise_text(msg),
  377. )