from __future__ import annotations

import asyncio
import json
from pathlib import Path

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason
from ..cookies import get_cookies_dir
from ..errors import ResponseStatusError, ModelNotFoundError
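
# Provider for the Cloudflare AI playground. The class fetches the available
# model list from models_url, streams chat completions from api_endpoint as
# "data: ..." server-sent-event lines, and caches browser-derived headers and
# cookies on disk so later runs can skip the nodriver step.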
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    working = True
    use_nodriver = True
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.3-70b-instruct-fp8-fast"
    # A dict literal silently keeps only the last value for a repeated key, so
    # each alias maps to exactly one upstream model; the other variants are
    # noted in comments.
    model_aliases = {
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",  # also: @cf/meta/llama-2-7b-chat-fp16
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",  # also: @cf/meta/llama-3-8b-instruct, @cf/meta/llama-3-8b-instruct-awq
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",  # also: @cf/meta/llama-3.1-8b-instruct-awq
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    _args: dict = None

    @classmethod
    def get_cache_file(cls) -> Path:
        # Cache of browser-derived request arguments, keyed by provider name.
        return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:
            if cls._args is None:
                if has_nodriver:
                    # Obtain headers and cookies from a real browser session.
                    get_running_loop(check_nested=True)
                    args = get_args_from_nodriver(cls.url)
                    cls._args = asyncio.run(args)
                elif not has_curl_cffi:
                    # No usable HTTP backend; return whatever is already known.
                    return cls.models
                else:
                    cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    raise_for_status(response)
                except ResponseStatusError:
                    return cls.models
                json_data = response.json()
                # Expected shape: {"models": [{"name": "..."}, ...]}
                cls.models = [model.get("name") for model in json_data.get("models")]
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        cache_file = cls.get_cache_file()
        if cls._args is None:
            if cache_file.exists():
                # Reuse previously cached headers and cookies.
                with cache_file.open("r") as f:
                    cls._args = json.load(f)
            elif has_nodriver:
                cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            # Pass unrecognized model names through to the API unchanged.
            pass
        data = {
            "messages": messages,
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # Credentials are stale; drop them so the next call
                    # regenerates the request arguments from scratch.
                    cls._args = None
                    if cache_file.exists():
                        cache_file.unlink()
                    raise
                reason = None
                async for line in response.iter_lines():
                    if line.startswith(b'data: '):
                        if line == b'data: [DONE]':
                            break
                        try:
                            content = json.loads(line[6:].decode())
                            if content.get("response") and content.get("response") != '</s>':
                                yield content['response']
                                reason = "max_tokens"
                            elif content.get("response") == '':
                                reason = "stop"
                        except Exception:
                            continue
                if reason is not None:
                    yield FinishReason(reason)
        with cache_file.open("w") as f:
            json.dump(cls._args, f)
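
# Minimal usage sketch (illustrative, not part of the provider). Because this
# module uses relative imports it cannot run standalone; in a package laid out
# like gpt4free's, something along these lines would consume the stream. The
# model alias and prompt below are placeholders.
#
#     import asyncio
#     from g4f.Provider import Cloudflare
#     from g4f.providers.response import FinishReason
#
#     async def demo():
#         async for chunk in Cloudflare.create_async_generator(
#             model="llama-3.2-1b",
#             messages=[{"role": "user", "content": "Hello!"}],
#         ):
#             if not isinstance(chunk, FinishReason):
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(demo())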