Replicate.py

from __future__ import annotations

from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, filter_none
from ...typing import AsyncResult, Messages
from ...requests import raise_for_status
from ...requests.aiohttp import StreamSession
from ...errors import ResponseError, MissingAuthError

class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://replicate.com"
    working = True
    needs_auth = True
    default_model = "meta/meta-llama-3-70b-instruct"
    model_aliases = {
        "meta-llama/Meta-Llama-3-70B-Instruct": default_model
    }

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_key: str = None,
        proxy: str = None,
        timeout: int = 180,
        system_prompt: str = None,
        max_new_tokens: int = None,
        temperature: float = None,
        top_p: float = None,
        top_k: float = None,
        stop: list = None,
        extra_data: dict = {},
        headers: dict = {
            "accept": "application/json",
        },
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)
        if cls.needs_auth and api_key is None:
            raise MissingAuthError("api_key is missing")
        if api_key is not None:
            # Authenticated requests go to the official API host.
            headers["Authorization"] = f"Bearer {api_key}"
            api_base = "https://api.replicate.com/v1/models/"
        else:
            api_base = "https://replicate.com/api/models/"
        async with StreamSession(
            proxy=proxy,
            headers=headers,
            timeout=timeout
        ) as session:
            # Build the prediction payload; filter_none drops sampling options that were not set.
            data = {
                "stream": True,
                "input": {
                    "prompt": format_prompt(messages),
                    **filter_none(
                        system_prompt=system_prompt,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        stop_sequences=",".join(stop) if stop else None
                    ),
                    **extra_data
                },
            }
            url = f"{api_base.rstrip('/')}/{model}/predictions"
            # Create the prediction, then stream its output via server-sent events.
            async with session.post(url, json=data) as response:
                message = "Model not found" if response.status == 404 else None
                await raise_for_status(response, message)
                result = await response.json()
                if "id" not in result:
                    raise ResponseError(f"Invalid response: {result}")
                async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response:
                    await raise_for_status(response)
                    event = None
                    async for line in response.iter_lines():
                        if line.startswith(b"event: "):
                            event = line[7:]
                            if event == b"done":
                                break
                        elif event == b"output":
                            if line.startswith(b"data: "):
                                # "output" events carry generated text; an empty data line means a newline.
                                new_text = line[6:].decode()
                                if new_text:
                                    yield new_text
                                else:
                                    yield "\n"
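
Since create_async_generator contains yield, calling it returns an async generator that can be iterated directly. Below is a minimal usage sketch, not part of Replicate.py itself; it assumes the module is importable as g4f.Provider.Replicate (its usual location in the g4f package) and that a valid token is available in the hypothetical REPLICATE_API_KEY environment variable.

import asyncio
import os

from g4f.Provider import Replicate  # assumed import path for this provider

async def main() -> None:
    messages = [{"role": "user", "content": "Say hello in one sentence."}]
    # Iterate over the streamed text chunks yielded by the provider.
    async for chunk in Replicate.create_async_generator(
        model="meta/meta-llama-3-70b-instruct",
        messages=messages,
        api_key=os.environ["REPLICATE_API_KEY"],  # required because needs_auth = True
        temperature=0.7,
    ):
        print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    asyncio.run(main())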