diff --git a/Tests/test_imagepalette.py b/Tests/test_imagepalette.py index 10b89a2c0c2..526beb656e8 100644 --- a/Tests/test_imagepalette.py +++ b/Tests/test_imagepalette.py @@ -1,6 +1,6 @@ from __future__ import annotations -from io import BytesIO +import io from pathlib import Path import pytest @@ -23,6 +23,13 @@ def test_reload() -> None: assert_image_equal(im.convert("RGB"), original.convert("RGB")) +def test_save_fp() -> None: + palette = ImagePalette.ImagePalette() + with io.StringIO() as fp: + palette.save(fp) + assert not fp.closed + + def test_getcolor() -> None: palette = ImagePalette.ImagePalette() assert len(palette.palette) == 0 @@ -204,7 +211,7 @@ def test_2bit_palette(tmp_path: Path) -> None: def test_getpalette() -> None: - b = BytesIO(b"0 1\n1 2 3 4") + b = io.BytesIO(b"0 1\n1 2 3 4") p = PaletteFile.PaletteFile(b) palette, rawmode = p.getpalette() @@ -216,6 +223,6 @@ def test_invalid_palette() -> None: with pytest.raises(OSError): ImagePalette.load("Tests/images/hopper.jpg") - b = BytesIO(b"1" * 101) + b = io.BytesIO(b"1" * 101) with pytest.raises(SyntaxError, match="bad palette file"): PaletteFile.PaletteFile(b) diff --git a/src/PIL/ImagePalette.py b/src/PIL/ImagePalette.py index eae7aea8fc3..99ad2771b4b 100644 --- a/src/PIL/ImagePalette.py +++ b/src/PIL/ImagePalette.py @@ -191,19 +191,24 @@ def save(self, fp: str | IO[str]) -> None: if self.rawmode: msg = "palette contains raw palette data" raise ValueError(msg) + open_fp = False if isinstance(fp, str): fp = open(fp, "w") - fp.write("# Palette\n") - fp.write(f"# Mode: {self.mode}\n") - for i in range(256): - fp.write(f"{i}") - for j in range(i * len(self.mode), (i + 1) * len(self.mode)): - try: - fp.write(f" {self.palette[j]}") - except IndexError: - fp.write(" 0") - fp.write("\n") - fp.close() + open_fp = True + try: + fp.write("# Palette\n") + fp.write(f"# Mode: {self.mode}\n") + for i in range(256): + fp.write(f"{i}") + for j in range(i * len(self.mode), (i + 1) * len(self.mode)): + try: + fp.write(f" {self.palette[j]}") + except IndexError: + fp.write(" 0") + fp.write("\n") + finally: + if open_fp: + fp.close() # --------------------------------------------------------------------