# create_provider.py — interactive helper that generates a new g4f provider
# module from a pasted cURL command by asking an LLM to write the code.
  1. import sys, re
  2. from pathlib import Path
  3. from os import path
  4. sys.path.append(str(Path(__file__).parent.parent.parent))
  5. import g4f
  6. g4f.debug.logging = True
  7. def read_code(text):
  8. if match := re.search(r"```(python|py|)\n(?P<code>[\S\s]+?)\n```", text):
  9. return match.group("code")
  10. def input_command():
  11. print("Enter/Paste the cURL command. Ctrl-D or Ctrl-Z ( windows ) to save it.")
  12. contents = []
  13. while True:
  14. try:
  15. line = input()
  16. except EOFError:
  17. break
  18. contents.append(line)
  19. return "\n".join(contents)
  20. name = input("Name: ")
  21. provider_path = f"g4f/Provider/{name}.py"
  22. example = """
  23. from __future__ import annotations
  24. from aiohttp import ClientSession
  25. from ..typing import AsyncResult, Messages
  26. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  27. from .helper import format_prompt
  28. class {name}(AsyncGeneratorProvider, ProviderModelMixin):
  29. label = ""
  30. url = "https://example.com"
  31. api_endpoint = "https://example.com/api/completion"
  32. working = True
  33. needs_auth = False
  34. supports_stream = True
  35. supports_system_message = True
  36. supports_message_history = True
  37. default_model = ''
  38. models = ['', '']
  39. model_aliases = {
  40. "alias1": "model1",
  41. }
  42. @classmethod
  43. def get_model(cls, model: str) -> str:
  44. if model in cls.models:
  45. return model
  46. elif model in cls.model_aliases:
  47. return cls.model_aliases[model]
  48. else:
  49. return cls.default_model
  50. @classmethod
  51. async def create_async_generator(
  52. cls,
  53. model: str,
  54. messages: Messages,
  55. proxy: str = None,
  56. **kwargs
  57. ) -> AsyncResult:
  58. model = cls.get_model(model)
  59. headers = {{
  60. "authority": "example.com",
  61. "accept": "application/json",
  62. "origin": cls.url,
  63. "referer": f"{{cls.url}}/chat",
  64. }}
  65. async with ClientSession(headers=headers) as session:
  66. prompt = format_prompt(messages)
  67. data = {{
  68. "prompt": prompt,
  69. "model": model,
  70. }}
  71. async with session.post(f"{{cls.url}}/api/chat", json=data, proxy=proxy) as response:
  72. response.raise_for_status()
  73. async for chunk in response.content:
  74. if chunk:
  75. yield chunk.decode()
  76. """
  77. if not path.isfile(provider_path):
  78. command = input_command()
  79. prompt = f"""
  80. Create a provider from a cURL command. The command is:
  81. ```bash
  82. {command}
  83. ```
  84. A example for a provider:
  85. ```python
  86. {example}
  87. ```
  88. The name for the provider class:
  89. {name}
  90. Replace "hello" with `format_prompt(messages)`.
  91. And replace "gpt-3.5-turbo" with `model`.
  92. """
  93. print("Create code...")
  94. response = []
  95. for chunk in g4f.ChatCompletion.create(
  96. model=g4f.models.default,
  97. messages=[{"role": "user", "content": prompt}],
  98. timeout=300,
  99. stream=True,
  100. ):
  101. print(chunk, end="", flush=True)
  102. response.append(chunk)
  103. print()
  104. response = "".join(response)
  105. if code := read_code(response):
  106. with open(provider_path, "w") as file:
  107. file.write(code)
  108. print("Saved at:", provider_path)
  109. with open("g4f/Provider/__init__.py", "a") as file:
  110. file.write(f"\nfrom .{name} import {name}")
  111. else:
  112. with open(provider_path, "r") as file:
  113. code = file.read()