helper.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from __future__ import annotations
  2. import sys
  3. import asyncio
  4. import webbrowser
  5. import random
  6. import string
  7. import secrets
  8. import os
  9. from os import path
  10. from asyncio import AbstractEventLoop
  11. from platformdirs import user_config_dir
  12. from browser_cookie3 import (
  13. chrome,
  14. chromium,
  15. opera,
  16. opera_gx,
  17. brave,
  18. edge,
  19. vivaldi,
  20. firefox,
  21. _LinuxPasswordManager
  22. )
  23. from ..typing import Dict, Messages
  24. from .. import debug
  25. # Change event loop policy on windows
  26. if sys.platform == 'win32':
  27. if isinstance(
  28. asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
  29. ):
  30. asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  31. # Local Cookie Storage
  32. _cookies: Dict[str, Dict[str, str]] = {}
  33. # If event loop is already running, handle nested event loops
  34. # If "nest_asyncio" is installed, patch the event loop.
  35. def get_event_loop() -> AbstractEventLoop:
  36. try:
  37. asyncio.get_running_loop()
  38. except RuntimeError:
  39. try:
  40. return asyncio.get_event_loop()
  41. except RuntimeError:
  42. asyncio.set_event_loop(asyncio.new_event_loop())
  43. return asyncio.get_event_loop()
  44. try:
  45. event_loop = asyncio.get_event_loop()
  46. if not hasattr(event_loop.__class__, "_nest_patched"):
  47. import nest_asyncio
  48. nest_asyncio.apply(event_loop)
  49. return event_loop
  50. except ImportError:
  51. raise RuntimeError(
  52. 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
  53. )
  54. def init_cookies():
  55. urls = [
  56. 'https://chat-gpt.org',
  57. 'https://www.aitianhu.com',
  58. 'https://chatgptfree.ai',
  59. 'https://gptchatly.com',
  60. 'https://bard.google.com',
  61. 'https://huggingface.co/chat',
  62. 'https://open-assistant.io/chat'
  63. ]
  64. browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
  65. def open_urls_in_browser(browser):
  66. b = webbrowser.get(browser)
  67. for url in urls:
  68. b.open(url, new=0, autoraise=True)
  69. for browser in browsers:
  70. try:
  71. open_urls_in_browser(browser)
  72. break
  73. except webbrowser.Error:
  74. continue
# Check for broken dbus address in docker image.
# When DBUS_SESSION_BUS_ADDRESS is exactly "/dev/null", browser_cookie3's
# private _LinuxPasswordManager.get_password is replaced, process-wide and at
# import time, by a stub that ignores its arguments and always returns the
# fixed bytes b"secret".
# NOTE(review): presumably this sidesteps a keyring/D-Bus password lookup that
# cannot work without a session bus; the lambda's (a, b) are presumably
# (self, <lookup key>) since it is installed as a method — confirm against
# the installed browser_cookie3 version.
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
    _LinuxPasswordManager.get_password = lambda a, b: b"secret"
  78. # Load cookies for a domain from all supported browsers.
  79. # Cache the results in the "_cookies" variable.
  80. def get_cookies(domain_name=''):
  81. if domain_name in _cookies:
  82. return _cookies[domain_name]
  83. def g4f(domain_name):
  84. user_data_dir = user_config_dir("g4f")
  85. cookie_file = path.join(user_data_dir, "Default", "Cookies")
  86. return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
  87. cookies = {}
  88. for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
  89. try:
  90. cookie_jar = cookie_fn(domain_name=domain_name)
  91. if len(cookie_jar) and debug.logging:
  92. print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
  93. for cookie in cookie_jar:
  94. if cookie.name not in cookies:
  95. cookies[cookie.name] = cookie.value
  96. except:
  97. pass
  98. _cookies[domain_name] = cookies
  99. return _cookies[domain_name]
  100. def format_prompt(messages: Messages, add_special_tokens=False) -> str:
  101. if not add_special_tokens and len(messages) <= 1:
  102. return messages[0]["content"]
  103. formatted = "\n".join([
  104. f'{message["role"].capitalize()}: {message["content"]}'
  105. for message in messages
  106. ])
  107. return f"{formatted}\nAssistant:"
  108. def get_random_string(length: int = 10) -> str:
  109. return ''.join(
  110. random.choice(string.ascii_lowercase + string.digits)
  111. for _ in range(length)
  112. )
  113. def get_random_hex() -> str:
  114. return secrets.token_hex(16).zfill(32)