Llama2.py

from __future__ import annotations

from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider

# Map Hugging Face model ids to the model names the llama2.ai API expects.
models = {
    "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat",
    "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat",
    "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat",
}


class Llama2(AsyncGeneratorProvider):
    url = "https://www.llama2.ai"
    working = True
    supports_message_history = True

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        **kwargs
    ) -> AsyncResult:
        # Default to the 70B chat model; translate known Hugging Face ids to API names.
        if not model:
            model = "meta/llama-2-70b-chat"
        elif model in models:
            model = models[model]
        headers = {
            "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/118.0",
            "Accept": "*/*",
            "Accept-Language": "de,en-US;q=0.7,en;q=0.3",
            "Accept-Encoding": "gzip, deflate, br",
            "Referer": f"{cls.url}/",
            "Content-Type": "text/plain;charset=UTF-8",
            "Origin": cls.url,
            "Connection": "keep-alive",
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "Pragma": "no-cache",
            "Cache-Control": "no-cache",
            "TE": "trailers"
        }
        async with ClientSession(headers=headers) as session:
            prompt = format_prompt(messages)
            data = {
                "prompt": prompt,
                "model": model,
                "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."),
                "temperature": kwargs.get("temperature", 0.75),
                "topP": kwargs.get("top_p", 0.9),
                "maxTokens": kwargs.get("max_tokens", 8000),
                "image": None
            }
            started = False
            async with session.post(f"{cls.url}/api", json=data, proxy=proxy) as response:
                response.raise_for_status()
                # Stream the reply; strip leading whitespace from the first chunk only.
                async for chunk in response.content.iter_any():
                    if not started:
                        chunk = chunk.lstrip()
                        started = True
                    yield chunk.decode()


def format_prompt(messages: Messages):
    # Wrap user turns in Llama 2 [INST] ... [/INST] tags; other roles pass through unchanged.
    messages = [
        f"[INST] {message['content']} [/INST]"
        if message["role"] == "user"
        else message["content"]
        for message in messages
    ]
    return "\n".join(messages) + "\n"
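

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of driving the async generator above. Because of the
# relative imports at the top, this only runs when the module is executed inside
# its package (e.g. via `python -m`); it also assumes Messages is a plain list of
# {"role": ..., "content": ...} dicts, which is what format_prompt expects.
if __name__ == "__main__":
    import asyncio

    async def main():
        # Any Hugging Face id from `models` works here; an empty model name
        # falls back to "meta/llama-2-70b-chat".
        messages = [{"role": "user", "content": "Explain what an async generator is."}]
        async for chunk in Llama2.create_async_generator(
            "meta-llama/Llama-2-70b-chat-hf", messages
        ):
            print(chunk, end="")

    asyncio.run(main())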