helper.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from __future__ import annotations
  2. import asyncio
  3. import webbrowser
  4. import random
  5. import string
  6. import secrets
  7. import os
  8. from os import path
  9. from asyncio import AbstractEventLoop, BaseEventLoop
  10. from platformdirs import user_config_dir
  11. from browser_cookie3 import (
  12. chrome,
  13. chromium,
  14. opera,
  15. opera_gx,
  16. brave,
  17. edge,
  18. vivaldi,
  19. firefox,
  20. _LinuxPasswordManager
  21. )
  22. from ..typing import Dict, Messages
  23. from .. import debug
  24. # Local Cookie Storage
  25. _cookies: Dict[str, Dict[str, str]] = {}
  26. # If loop closed or not set, create new event loop.
  27. # If event loop is already running, handle nested event loops.
  28. # If "nest_asyncio" is installed, patch the event loop.
  29. def get_event_loop() -> AbstractEventLoop:
  30. try:
  31. loop = asyncio.get_event_loop()
  32. if isinstance(loop, BaseEventLoop):
  33. loop._check_closed()
  34. except RuntimeError:
  35. loop = asyncio.new_event_loop()
  36. asyncio.set_event_loop(loop)
  37. try:
  38. # Is running event loop
  39. asyncio.get_running_loop()
  40. if not hasattr(loop.__class__, "_nest_patched"):
  41. import nest_asyncio
  42. nest_asyncio.apply(loop)
  43. except RuntimeError:
  44. # No running event loop
  45. pass
  46. except ImportError:
  47. raise RuntimeError(
  48. 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
  49. )
  50. return loop
  51. def init_cookies():
  52. urls = [
  53. 'https://chat-gpt.org',
  54. 'https://www.aitianhu.com',
  55. 'https://chatgptfree.ai',
  56. 'https://gptchatly.com',
  57. 'https://bard.google.com',
  58. 'https://huggingface.co/chat',
  59. 'https://open-assistant.io/chat'
  60. ]
  61. browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
  62. def open_urls_in_browser(browser):
  63. b = webbrowser.get(browser)
  64. for url in urls:
  65. b.open(url, new=0, autoraise=True)
  66. for browser in browsers:
  67. try:
  68. open_urls_in_browser(browser)
  69. break
  70. except webbrowser.Error:
  71. continue
# Check for broken dbus address in docker image.
# When DBUS_SESSION_BUS_ADDRESS is exactly "/dev/null", replace
# browser_cookie3's _LinuxPasswordManager.get_password with a stub. The stub
# ignores its arguments (``a`` is the instance, ``b`` the second positional
# argument) and always returns b"secret".
# NOTE(review): presumably this avoids a password/keyring lookup over the
# unusable DBus session when decrypting browser cookies on Linux. Confirm
# against the browser_cookie3 version in use, since _LinuxPasswordManager is
# a private class of that library.
# This runs at import time and patches the class itself, so it affects every
# user of browser_cookie3 in the process.
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
    _LinuxPasswordManager.get_password = lambda a, b: b"secret"
  75. # Load cookies for a domain from all supported browsers.
  76. # Cache the results in the "_cookies" variable.
  77. def get_cookies(domain_name=''):
  78. if domain_name in _cookies:
  79. return _cookies[domain_name]
  80. def g4f(domain_name):
  81. user_data_dir = user_config_dir("g4f")
  82. cookie_file = path.join(user_data_dir, "Default", "Cookies")
  83. return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
  84. cookies = {}
  85. for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
  86. try:
  87. cookie_jar = cookie_fn(domain_name=domain_name)
  88. if len(cookie_jar) and debug.logging:
  89. print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
  90. for cookie in cookie_jar:
  91. if cookie.name not in cookies:
  92. cookies[cookie.name] = cookie.value
  93. except:
  94. pass
  95. _cookies[domain_name] = cookies
  96. return _cookies[domain_name]
  97. def format_prompt(messages: Messages, add_special_tokens=False) -> str:
  98. if not add_special_tokens and len(messages) <= 1:
  99. return messages[0]["content"]
  100. formatted = "\n".join([
  101. f'{message["role"].capitalize()}: {message["content"]}'
  102. for message in messages
  103. ])
  104. return f"{formatted}\nAssistant:"
  105. def get_random_string(length: int = 10) -> str:
  106. return ''.join(
  107. random.choice(string.ascii_lowercase + string.digits)
  108. for _ in range(length)
  109. )
  110. def get_random_hex() -> str:
  111. return secrets.token_hex(16).zfill(32)