From fc1422f099d95878209c92b3e9e2f509fe8ca77e Mon Sep 17 00:00:00 2001
From: Ben Sima <ben@bsima.me>
Date: Wed, 4 Dec 2024 21:55:02 -0500
Subject: Add some mock tests of the Image endpoint

These were contributed in part by gptme, thanks!
---
 Biz/Storybook.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 91 insertions(+), 10 deletions(-)

(limited to 'Biz/Storybook.py')

diff --git a/Biz/Storybook.py b/Biz/Storybook.py
index 3659c37..c619ef8 100644
--- a/Biz/Storybook.py
+++ b/Biz/Storybook.py
@@ -38,6 +38,7 @@ import sys
 import time
 import typing
 import unittest
+import unittest.mock as mock
 import uvicorn
 
 MOCK = True
@@ -365,7 +366,7 @@ class Images(ludic.web.Endpoint[Image]):
         """
         if MOCK:
             # Simulate slow image generation
-            time.sleep(3)
+            time.sleep(1)
             return ludic.web.responses.FileResponse(
                 DATA_DIR / "images" / "placeholder.jpg",
             )
@@ -449,18 +450,98 @@ class ImagesTest(unittest.TestCase):
     def setUp(self) -> None:
         """Create test client."""
         self.client = starlette.testclient.TestClient(app)
+        self.story_id = "Uk"
+        self.page = 1
+        self.valid_prompt = {"text": "A beautiful sunset over the ocean"}
+
+    def test_image_get_existing(self) -> None:
+        """Test retrieving an existing image."""
+        # Arrange: Mock the load_by_id method to simulate an existing image
+        data = {"path": DATA_DIR / "images" / "placeholder.jpg"}
+        mock_dict = mock.MagicMock()
+        mock_dict.__getitem__.side_effect = data.__getitem__
+        with mock.patch.object(
+            Images,
+            "load_by_id",
+            return_value=mock_dict,
+        ):
+            # Act: Send a GET request to retrieve the image
+            response = self.client.get(
+                app.url_path_for(
+                    "Images",
+                    story_id=self.story_id,
+                    page=self.page,
+                ),
+            )
+            # Assert: Check that the response status is 200
+            self.assertEqual(response.status_code, 200)
 
-    def test_image_post(self) -> None:
-        """Can POST an Image successfully."""
-        response = self.client.post(
-            app.url_path_for(
-                "Images",
-                story_id="Uk",
-                page=1,
+    def test_image_get_nonexistent(self) -> None:
+        """Test retrieving a non-existent image."""
+        # Act: Send a GET request for a non-existent image
+        response = self.client.get(
+            app.url_path_for("Images", story_id="nonexistent", page=self.page),
+        )
+        # Assert: Check that the response status is 404
+        self.assertEqual(response.status_code, 404)
+
+    def test_image_post_valid(self) -> None:
+        """Test creating an image with valid data."""
+        # Arrange: Mock the OpenAI API and file system operations
+        with (
+            mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai,
+            mock.patch(
+                "Biz.Storybook.pathlib.Path.write_bytes",
             ),
-            data={"text": "lorem ipsum"},
+        ):
+            mock_openai.return_value.images.generate.return_value.data = [
+                mock.MagicMock(url="http://example.com/image.jpg"),
+            ]
+            # Act: Send a POST request with valid data
+            response = self.client.post(
+                app.url_path_for(
+                    "Images",
+                    story_id=self.story_id,
+                    page=self.page,
+                ),
+                data=self.valid_prompt,
+            )
+            # Assert: Check that the response status is 200
+            self.assertEqual(response.status_code, 200)
+
+    def test_image_post_invalid(self) -> None:
+        """Test creating an image with invalid data."""
+        # Act: Send a POST request with invalid data
+        response = self.client.post(
+            app.url_path_for("Images", story_id=self.story_id, page=self.page),
+            data={"invalid": "data"},
         )
-        self.assertEqual(response.status_code, 200)
+        # Assert: Check that the response status indicates an error
+        self.assertNotEqual(response.status_code, 200)
+
+    def test_image_put_overwrite(self) -> None:
+        """Test overwriting an existing image."""
+        # Arrange: Mock the OpenAI API and file system operations
+        with (
+            mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai,
+            mock.patch(
+                "Biz.Storybook.pathlib.Path.write_bytes",
+            ),
+        ):
+            mock_openai.return_value.images.generate.return_value.data = [
+                mock.MagicMock(url="http://example.com/image.jpg"),
+            ]
+            # Act: Send a PUT request to overwrite the image
+            response = self.client.put(
+                app.url_path_for(
+                    "Images",
+                    story_id=self.story_id,
+                    page=self.page,
+                ),
+                data=self.valid_prompt,
+            )
+            # Assert: Check that the response status is 200
+            self.assertEqual(response.status_code, 200)
 
 
 @app.endpoint("/stories/{story_id:str}")
-- 
cgit v1.2.3