create_images.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. import re
  3. from ..typing import CreateResult, Messages
  4. from ..base_provider import BaseProvider, ProviderType
  5. from .. import debug
  6. system_message = """
  7. You can generate custom images with the DALL-E 3 image generator.
  8. To generate a image with a prompt, do this:
  9. <img data-prompt=\"keywords for the image\">
  10. Don't use images with data uri. It is important to use a prompt instead.
  11. <img data-prompt=\"image caption\">
  12. """
  13. class CreateImagesProvider(BaseProvider):
  14. def __init__(
  15. self,
  16. provider: ProviderType,
  17. create_images: callable,
  18. system_message: str = system_message
  19. ) -> None:
  20. self.provider = provider
  21. self.create_images = create_images
  22. self.system_message = system_message
  23. self. __name__ = provider.__name__
  24. if hasattr(provider, "url"):
  25. self.url = provider.url
  26. self.working = provider.working
  27. self.supports_stream = provider.supports_stream
  28. def create_completion(
  29. self,
  30. model: str,
  31. messages: Messages,
  32. stream: bool = False,
  33. **kwargs
  34. ) -> CreateResult:
  35. messages.insert(0, {"role": "system", "content": self.system_message})
  36. image_placeholder = ""
  37. for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
  38. if image_placeholder or "<" in chunk:
  39. image_placeholder += chunk
  40. if ">" in image_placeholder:
  41. result = re.search(r'<img data-prompt="(.*?)"', image_placeholder)
  42. if result:
  43. prompt = result.group(1)
  44. if debug.logging:
  45. print(f"Create images with prompt: {prompt}")
  46. yield from self.create_images(prompt)
  47. else:
  48. yield image_placeholder
  49. image_placeholder = ""
  50. else:
  51. yield chunk
  52. async def create_async(
  53. self,
  54. model: str,
  55. messages: Messages,
  56. **kwargs
  57. ) -> str:
  58. messages.insert(0, {"role": "system", "content": self.system_message})
  59. response = await self.provider.create_async(model, messages, **kwargs)
  60. result = re.search(r'<img data-prompt="(.*?)">', response)
  61. if result:
  62. search = result.group(0)
  63. prompt = result.group(1)
  64. images = "".join([*self.create_images(prompt)])
  65. return response.replace(search, images)
  66. return response