diff --git a/README.md b/README.md index e030c40..51efedb 100644 --- a/README.md +++ b/README.md @@ -164,14 +164,14 @@ result = md.convert("test.pdf") print(result.text_content) ``` -To use Large Language Models for image descriptions, provide `llm_client` and `llm_model`: +To use Large Language Models for image descriptions (currently only for pptx and image files), provide `llm_client` and `llm_model`: ```python from markitdown import MarkItDown from openai import OpenAI client = OpenAI() -md = MarkItDown(llm_client=client, llm_model="gpt-4o") +md = MarkItDown(llm_client=client, llm_model="gpt-4o", llm_prompt="optional custom prompt") result = md.convert("example.jpg") print(result.text_content) ``` diff --git a/packages/markitdown/src/markitdown/_markitdown.py b/packages/markitdown/src/markitdown/_markitdown.py index 3027efc..702b10c 100644 --- a/packages/markitdown/src/markitdown/_markitdown.py +++ b/packages/markitdown/src/markitdown/_markitdown.py @@ -115,6 +115,7 @@ class MarkItDown: # TODO - remove these (see enable_builtins) self._llm_client: Any = None self._llm_model: Union[str | None] = None + self._llm_prompt: Union[str | None] = None self._exiftool_path: Union[str | None] = None self._style_map: Union[str | None] = None @@ -139,6 +140,7 @@ class MarkItDown: # TODO: Move these into converter constructors self._llm_client = kwargs.get("llm_client") self._llm_model = kwargs.get("llm_model") + self._llm_prompt = kwargs.get("llm_prompt") self._exiftool_path = kwargs.get("exiftool_path") self._style_map = kwargs.get("style_map") @@ -559,6 +561,9 @@ class MarkItDown: if "llm_model" not in _kwargs and self._llm_model is not None: _kwargs["llm_model"] = self._llm_model + if "llm_prompt" not in _kwargs and self._llm_prompt is not None: + _kwargs["llm_prompt"] = self._llm_prompt + if "style_map" not in _kwargs and self._style_map is not None: _kwargs["style_map"] = self._style_map diff --git a/packages/markitdown/tests/test_module_misc.py b/packages/markitdown/tests/test_module_misc.py index 447e28a..03e123d 100644 --- a/packages/markitdown/tests/test_module_misc.py +++ b/packages/markitdown/tests/test_module_misc.py @@ -4,6 +4,7 @@ import os import re import shutil import pytest +from unittest.mock import MagicMock from markitdown._uri_utils import parse_data_uri, file_uri_to_path @@ -370,6 +371,50 @@ def test_markitdown_exiftool() -> None: assert target in result.text_content +def test_markitdown_llm_parameters() -> None: + """Test that LLM parameters are correctly passed to the client.""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock( + content="Test caption with red circle and blue square 5bda1dd6" + ) + ) + ] + mock_client.chat.completions.create.return_value = mock_response + + test_prompt = "You are a professional test prompt." + markitdown = MarkItDown( + llm_client=mock_client, llm_model="gpt-4o", llm_prompt=test_prompt + ) + + # Test image file + markitdown.convert(os.path.join(TEST_FILES_DIR, "test_llm.jpg")) + + # Verify the prompt was passed to the OpenAI API + assert mock_client.chat.completions.create.called + call_args = mock_client.chat.completions.create.call_args + messages = call_args[1]["messages"] + assert len(messages) == 1 + assert messages[0]["content"][0]["text"] == test_prompt + + # Reset the mock for the next test + mock_client.chat.completions.create.reset_mock() + + # TODO: may only use one test after the llm caption method duplicate has been removed: + # https://github.com/microsoft/markitdown/pull/1254 + # Test PPTX file + markitdown.convert(os.path.join(TEST_FILES_DIR, "test.pptx")) + + # Verify the prompt was passed to the OpenAI API for PPTX images too + assert mock_client.chat.completions.create.called + call_args = mock_client.chat.completions.create.call_args + messages = call_args[1]["messages"] + assert len(messages) == 1 + assert messages[0]["content"][0]["text"] == test_prompt + + @pytest.mark.skipif( skip_llm, reason="do not run llm tests without a key", @@ -408,6 +453,7 @@ if __name__ == "__main__": test_speech_transcription, test_exceptions, test_markitdown_exiftool, + test_markitdown_llm_parameters, test_markitdown_llm, ]: print(f"Running {test.__name__}...", end="")