Cerebras.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from __future__ import annotations
  2. import requests
  3. from aiohttp import ClientSession
  4. from .OpenaiAPI import OpenaiAPI
  5. from ...typing import AsyncResult, Messages, Cookies
  6. from ...requests.raise_for_status import raise_for_status
  7. from ...cookies import get_cookies
  8. class Cerebras(OpenaiAPI):
  9. label = "Cerebras Inference"
  10. url = "https://inference.cerebras.ai/"
  11. working = True
  12. default_model = "llama3.1-70b"
  13. fallback_models = [
  14. "llama3.1-70b",
  15. "llama3.1-8b",
  16. ]
  17. model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
  18. @classmethod
  19. def get_models(cls, api_key: str = None):
  20. if not cls.models:
  21. try:
  22. headers = {}
  23. if api_key:
  24. headers["authorization"] = f"Bearer ${api_key}"
  25. response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
  26. raise_for_status(response)
  27. data = response.json()
  28. cls.models = [model.get("model") for model in data.get("models")]
  29. except Exception:
  30. cls.models = cls.fallback_models
  31. return cls.models
  32. @classmethod
  33. async def create_async_generator(
  34. cls,
  35. model: str,
  36. messages: Messages,
  37. api_base: str = "https://api.cerebras.ai/v1",
  38. api_key: str = None,
  39. cookies: Cookies = None,
  40. **kwargs
  41. ) -> AsyncResult:
  42. if api_key is None and cookies is None:
  43. cookies = get_cookies(".cerebras.ai")
  44. async with ClientSession(cookies=cookies) as session:
  45. async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
  46. raise_for_status(response)
  47. data = await response.json()
  48. if data:
  49. api_key = data.get("user", {}).get("demoApiKey")
  50. async for chunk in super().create_async_generator(
  51. model, messages,
  52. api_base=api_base,
  53. impersonate="chrome",
  54. api_key=api_key,
  55. headers={
  56. "User-Agent": "ex/JS 1.5.0",
  57. },
  58. **kwargs
  59. ):
  60. yield chunk