Replicate.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from __future__ import annotations
  2. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  3. from ..helper import format_prompt, filter_none
  4. from ...typing import AsyncResult, Messages
  5. from ...requests import raise_for_status
  6. from ...requests.aiohttp import StreamSession
  7. from ...errors import ResponseError, MissingAuthError
  8. class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
  9. url = "https://replicate.com"
  10. login_url = "https://replicate.com/account/api-tokens"
  11. working = True
  12. needs_auth = True
  13. default_model = "meta/meta-llama-3-70b-instruct"
  14. models = [default_model]
  15. @classmethod
  16. async def create_async_generator(
  17. cls,
  18. model: str,
  19. messages: Messages,
  20. api_key: str = None,
  21. proxy: str = None,
  22. timeout: int = 180,
  23. system_prompt: str = None,
  24. max_tokens: int = None,
  25. temperature: float = None,
  26. top_p: float = None,
  27. top_k: float = None,
  28. stop: list = None,
  29. extra_data: dict = {},
  30. headers: dict = {
  31. "accept": "application/json",
  32. },
  33. **kwargs
  34. ) -> AsyncResult:
  35. model = cls.get_model(model)
  36. if cls.needs_auth and api_key is None:
  37. raise MissingAuthError("api_key is missing")
  38. if api_key is not None:
  39. headers["Authorization"] = f"Bearer {api_key}"
  40. api_base = "https://api.replicate.com/v1/models/"
  41. else:
  42. api_base = "https://replicate.com/api/models/"
  43. async with StreamSession(
  44. proxy=proxy,
  45. headers=headers,
  46. timeout=timeout
  47. ) as session:
  48. data = {
  49. "stream": True,
  50. "input": {
  51. "prompt": format_prompt(messages),
  52. **filter_none(
  53. system_prompt=system_prompt,
  54. max_new_tokens=max_tokens,
  55. temperature=temperature,
  56. top_p=top_p,
  57. top_k=top_k,
  58. stop_sequences=",".join(stop) if stop else None
  59. ),
  60. **extra_data
  61. },
  62. }
  63. url = f"{api_base.rstrip('/')}/{model}/predictions"
  64. async with session.post(url, json=data) as response:
  65. message = "Model not found" if response.status == 404 else None
  66. await raise_for_status(response, message)
  67. result = await response.json()
  68. if "id" not in result:
  69. raise ResponseError(f"Invalid response: {result}")
  70. async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response:
  71. await raise_for_status(response)
  72. event = None
  73. async for line in response.iter_lines():
  74. if line.startswith(b"event: "):
  75. event = line[7:]
  76. if event == b"done":
  77. break
  78. elif event == b"output":
  79. if line.startswith(b"data: "):
  80. new_text = line[6:].decode()
  81. if new_text:
  82. yield new_text
  83. else:
  84. yield "\n"