image.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import re
  2. from io import BytesIO
  3. import base64
  4. from .typing import ImageType, Union
  5. from PIL import Image
  6. ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
  7. def to_image(image: ImageType) -> Image.Image:
  8. """
  9. Converts the input image to a PIL Image object.
  10. Args:
  11. image (Union[str, bytes, Image.Image]): The input image.
  12. Returns:
  13. Image.Image: The converted PIL Image object.
  14. """
  15. if isinstance(image, str):
  16. is_data_uri_an_image(image)
  17. image = extract_data_uri(image)
  18. if isinstance(image, bytes):
  19. is_accepted_format(image)
  20. image = Image.open(BytesIO(image))
  21. elif not isinstance(image, Image.Image):
  22. image = Image.open(image)
  23. copy = image.copy()
  24. copy.format = image.format
  25. image = copy
  26. return image
  27. def is_allowed_extension(filename: str) -> bool:
  28. """
  29. Checks if the given filename has an allowed extension.
  30. Args:
  31. filename (str): The filename to check.
  32. Returns:
  33. bool: True if the extension is allowed, False otherwise.
  34. """
  35. return '.' in filename and \
  36. filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  37. def is_data_uri_an_image(data_uri: str) -> bool:
  38. """
  39. Checks if the given data URI represents an image.
  40. Args:
  41. data_uri (str): The data URI to check.
  42. Raises:
  43. ValueError: If the data URI is invalid or the image format is not allowed.
  44. """
  45. # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
  46. if not re.match(r'data:image/(\w+);base64,', data_uri):
  47. raise ValueError("Invalid data URI image.")
  48. # Extract the image format from the data URI
  49. image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
  50. # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
  51. if image_format.lower() not in ALLOWED_EXTENSIONS:
  52. raise ValueError("Invalid image format (from mime file type).")
  53. def is_accepted_format(binary_data: bytes) -> bool:
  54. """
  55. Checks if the given binary data represents an image with an accepted format.
  56. Args:
  57. binary_data (bytes): The binary data to check.
  58. Raises:
  59. ValueError: If the image format is not allowed.
  60. """
  61. if binary_data.startswith(b'\xFF\xD8\xFF'):
  62. pass # It's a JPEG image
  63. elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
  64. pass # It's a PNG image
  65. elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
  66. pass # It's a GIF image
  67. elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
  68. pass # It's a JPEG image
  69. elif binary_data.startswith(b'\xFF\xD8'):
  70. pass # It's a JPEG image
  71. elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
  72. pass # It's a WebP image
  73. else:
  74. raise ValueError("Invalid image format (from magic code).")
  75. def extract_data_uri(data_uri: str) -> bytes:
  76. """
  77. Extracts the binary data from the given data URI.
  78. Args:
  79. data_uri (str): The data URI.
  80. Returns:
  81. bytes: The extracted binary data.
  82. """
  83. data = data_uri.split(",")[1]
  84. data = base64.b64decode(data)
  85. return data
  86. def get_orientation(image: Image.Image) -> int:
  87. """
  88. Gets the orientation of the given image.
  89. Args:
  90. image (Image.Image): The image.
  91. Returns:
  92. int: The orientation value.
  93. """
  94. exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
  95. if exif_data is not None:
  96. orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
  97. if orientation is not None:
  98. return orientation
  99. def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
  100. """
  101. Processes the given image by adjusting its orientation and resizing it.
  102. Args:
  103. img (Image.Image): The image to process.
  104. new_width (int): The new width of the image.
  105. new_height (int): The new height of the image.
  106. Returns:
  107. Image.Image: The processed image.
  108. """
  109. orientation = get_orientation(img)
  110. if orientation:
  111. if orientation > 4:
  112. img = img.transpose(Image.FLIP_LEFT_RIGHT)
  113. if orientation in [3, 4]:
  114. img = img.transpose(Image.ROTATE_180)
  115. if orientation in [5, 6]:
  116. img = img.transpose(Image.ROTATE_270)
  117. if orientation in [7, 8]:
  118. img = img.transpose(Image.ROTATE_90)
  119. img.thumbnail((new_width, new_height))
  120. return img
  121. def to_base64(image: Image.Image, compression_rate: float) -> str:
  122. """
  123. Converts the given image to a base64-encoded string.
  124. Args:
  125. image (Image.Image): The image to convert.
  126. compression_rate (float): The compression rate (0.0 to 1.0).
  127. Returns:
  128. str: The base64-encoded image.
  129. """
  130. output_buffer = BytesIO()
  131. image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
  132. return base64.b64encode(output_buffer.getvalue()).decode()
  133. def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
  134. """
  135. Formats the given images as a markdown string.
  136. Args:
  137. images: The images to format.
  138. prompt (str): The prompt for the images.
  139. preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
  140. Returns:
  141. str: The formatted markdown string.
  142. """
  143. if isinstance(images, list):
  144. images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
  145. images = "\n".join(images)
  146. else:
  147. images = f"[![{prompt}]({images})]({images})"
  148. start_flag = "<!-- generated images start -->\n"
  149. end_flag = "<!-- generated images end -->\n"
  150. return f"\n{start_flag}{images}\n{end_flag}\n"
  151. def to_bytes(image: Image.Image) -> bytes:
  152. """
  153. Converts the given image to bytes.
  154. Args:
  155. image (Image.Image): The image to convert.
  156. Returns:
  157. bytes: The image as bytes.
  158. """
  159. bytes_io = BytesIO()
  160. image.save(bytes_io, image.format)
  161. image.seek(0)
  162. return bytes_io.getvalue()
  163. class ImageResponse():
  164. def __init__(
  165. self,
  166. images: Union[str, list],
  167. alt: str,
  168. options: dict = {}
  169. ):
  170. self.images = images
  171. self.alt = alt
  172. self.options = options
  173. def __str__(self) -> str:
  174. return format_images_markdown(self.images, self.alt)
  175. def get(self, key: str):
  176. return self.options.get(key)