Source code for stable_datasets.features.image
"""Image feature codec."""
from __future__ import annotations

import io
import os
from pathlib import Path

from .base import FeatureType
[docs]
class Image(FeatureType):
    """Feature type for images, persisted in Arrow as raw encoded bytes.

    Args:
        encode_format: Pillow format name handed to the encoder for values
            that must be encoded (e.g. numpy arrays). Defaults to ``"PNG"``.
    """

    def __init__(self, encode_format: str = "PNG"):
        self.encode_format = encode_format

    def to_arrow_type(self):
        """Return the Arrow storage type, ``large_binary`` (64-bit offsets)."""
        # Imported lazily so pyarrow is only needed when a schema is built.
        from pyarrow import large_binary

        return large_binary()

    def encode(self, value, *, cache_dir: Path | None = None) -> bytes | None:
        """Encode *value* into bytes (or ``None``) via ``_encode_image_value``.

        ``cache_dir`` is accepted for interface compatibility; it is not used.
        """
        fmt = self.encode_format
        return _encode_image_value(value, encode_format=fmt)

    def __repr__(self) -> str:
        return "Image(encode_format='{}')".format(self.encode_format)
def _encode_image_value(img, encode_format: str = "PNG") -> bytes | None:
    """Serialize *img* to raw image-file bytes for Arrow storage.

    Inputs are checked in this order:

    * ``None`` -> ``None`` (missing value).
    * ``bytes`` -> returned unchanged; ``bytearray`` / ``memoryview`` are
      copied into an immutable ``bytes`` object.
    * ``str`` or any ``os.PathLike`` -> the file's contents, read verbatim.
    * ``PIL.Image.Image`` -> the bytes of its source file when ``filename``
      points at an existing file; otherwise re-encoded in the image's own
      ``format`` (falling back to *encode_format* when it has none). Modes
      carrying alpha or a palette (``RGBA``, ``LA``, ``PA``, ``P``) are always
      written as PNG.
    * ``numpy.ndarray`` -> converted with ``PIL.Image.fromarray`` and written
      as *encode_format*.

    NOTE: an image loaded from disk and then modified in place may still carry
    its ``filename``; in that case the original, unmodified file bytes are
    returned.

    Args:
        img: The value to encode.
        encode_format: Pillow format name used for arrays and for in-memory
            PIL images that have no format of their own.

    Returns:
        The encoded bytes, or ``None`` when *img* is ``None``.

    Raises:
        TypeError: If *img* is of an unsupported type. Errors from ``open`` or
            from Pillow's ``save`` propagate unchanged.
    """
    if img is None:
        return None
    if isinstance(img, bytes):
        return img
    if isinstance(img, (bytearray, memoryview)):
        # Normalize mutable buffers / views to immutable bytes.
        return bytes(img)
    # A tuple of types works on every Python 3 version, unlike the
    # ``str | Path`` union form (3.10+), and accepts any os.PathLike.
    if isinstance(img, (str, os.PathLike)):
        with open(img, "rb") as f:
            return f.read()

    from PIL import Image as PILImage

    if isinstance(img, PILImage.Image):
        src = getattr(img, "filename", None)
        # Prefer the original file bytes: no re-encoding cost or quality loss.
        # os.path.isfile accepts str, bytes and PathLike filenames alike.
        if src and os.path.isfile(src):
            with open(src, "rb") as f:
                return f.read()
        fmt = getattr(img, "format", None) or encode_format
        # Formats such as JPEG cannot store alpha or palette data.
        if img.mode in ("RGBA", "LA", "PA", "P"):
            fmt = "PNG"
        buf = io.BytesIO()
        img.save(buf, format=fmt)
        return buf.getvalue()

    # numpy is only needed from here on, so PIL images encode without it.
    import numpy as np

    if isinstance(img, np.ndarray):
        buf = io.BytesIO()
        PILImage.fromarray(img).save(buf, format=encode_format)
        return buf.getvalue()

    raise TypeError(f"Cannot encode image of type {type(img)}")