diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py index 99cd9928..1bf7fc92 100644 --- a/diffsynth/core/data/operators.py +++ b/diffsynth/core/data/operators.py @@ -1,9 +1,27 @@ import math, warnings import torch, torchvision, imageio, os import imageio.v3 as iio +import urllib.request +from io import BytesIO from PIL import Image +def _is_url(path): + return isinstance(path, str) and path.startswith(("http://", "https://")) + + +def _load_image_from_url(url, timeout=30): + # A custom User-Agent avoids the 403 that many CDNs / image hosts return for the default urllib agent. + request = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0 (compatible; DiffSynth-Studio)"}) + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + image = Image.open(BytesIO(response.read())) + image.load() # force a full decode while the byte buffer is still in scope + return image + except Exception as e: + raise RuntimeError(f"Failed to load image from URL '{url}': {e}") from e + + class DataProcessingPipeline: def __init__(self, operators=None): self.operators: list[DataProcessingOperator] = [] if operators is None else operators @@ -54,12 +72,16 @@ def __call__(self, data): class LoadImage(DataProcessingOperator): - def __init__(self, convert_RGB=True, convert_RGBA=False): + def __init__(self, convert_RGB=True, convert_RGBA=False, timeout=30): self.convert_RGB = convert_RGB self.convert_RGBA = convert_RGBA - + self.timeout = timeout + def __call__(self, data: str): - image = Image.open(data) + if _is_url(data): + image = _load_image_from_url(data, timeout=self.timeout) + else: + image = Image.open(data) if self.convert_RGB: image = image.convert("RGB") if self.convert_RGBA: image = image.convert("RGBA") return image @@ -241,6 +263,9 @@ def __init__(self, base_path=""): self.base_path = base_path def __call__(self, data): + # A remote URL is already absolute; joining it with base_path would corrupt it. + if _is_url(data): + return data return os.path.join(self.base_path, data)