helper.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from __future__ import annotations
  2. import random
  3. import string
  4. from ..typing import Messages, Cookies
  5. from .. import debug
  6. def format_prompt(messages: Messages, add_special_tokens=False) -> str:
  7. """
  8. Format a series of messages into a single string, optionally adding special tokens.
  9. Args:
  10. messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
  11. add_special_tokens (bool): Whether to add special formatting tokens.
  12. Returns:
  13. str: A formatted string containing all messages.
  14. """
  15. if not add_special_tokens and len(messages) <= 1:
  16. return messages[0]["content"]
  17. formatted = "\n".join([
  18. f'{message["role"].capitalize()}: {message["content"]}'
  19. for message in messages
  20. ])
  21. return f"{formatted}\nAssistant:"
  22. def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
  23. prompt = format_prompt(messages)
  24. start = len(prompt)
  25. if start > max_lenght:
  26. if len(messages) > 6:
  27. prompt = format_prompt(messages[:3] + messages[-3:])
  28. if len(prompt) > max_lenght:
  29. if len(messages) > 2:
  30. prompt = format_prompt([m for m in messages if m["role"] == "system"] + messages[-1:])
  31. if len(prompt) > max_lenght:
  32. prompt = messages[-1]["content"]
  33. debug.log(f"Messages trimmed from: {start} to: {len(prompt)}")
  34. return prompt
  35. def get_random_string(length: int = 10) -> str:
  36. """
  37. Generate a random string of specified length, containing lowercase letters and digits.
  38. Args:
  39. length (int, optional): Length of the random string to generate. Defaults to 10.
  40. Returns:
  41. str: A random string of the specified length.
  42. """
  43. return ''.join(
  44. random.choice(string.ascii_lowercase + string.digits)
  45. for _ in range(length)
  46. )
  47. def get_random_hex(length: int = 32) -> str:
  48. """
  49. Generate a random hexadecimal string with n length.
  50. Returns:
  51. str: A random hexadecimal string of n characters.
  52. """
  53. return ''.join(
  54. random.choice("abcdef" + string.digits)
  55. for _ in range(length)
  56. )
  57. def filter_none(**kwargs) -> dict:
  58. return {
  59. key: value
  60. for key, value in kwargs.items()
  61. if value is not None
  62. }
  63. def format_cookies(cookies: Cookies) -> str:
  64. return "; ".join([f"{k}={v}" for k, v in cookies.items()])