fix: correctly pass custom llm prompt parameter (#1319)
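Before this change, an `llm_prompt` passed to the `MarkItDown` constructor was never stored or forwarded, so converters silently fell back to the default captioning prompt. With the fix, the parameter flows from the constructor kwargs into each conversion. A minimal usage sketch (the OpenAI client setup and the image path are illustrative placeholders, not from this diff):

from markitdown import MarkItDown
from openai import OpenAI  # assumes the optional OpenAI dependency is installed

client = OpenAI()
md = MarkItDown(
    llm_client=client,
    llm_model="gpt-4o",
    llm_prompt="Write a one-sentence alt text for this image.",  # custom prompt, now honored
)
result = md.convert("photo.jpg")  # hypothetical input image
print(result.text_content)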
@@ -115,6 +115,7 @@ class MarkItDown:
         # TODO - remove these (see enable_builtins)
         self._llm_client: Any = None
         self._llm_model: Union[str, None] = None
+        self._llm_prompt: Union[str, None] = None
         self._exiftool_path: Union[str, None] = None
         self._style_map: Union[str, None] = None
@@ -139,6 +140,7 @@ class MarkItDown:
         # TODO: Move these into converter constructors
         self._llm_client = kwargs.get("llm_client")
         self._llm_model = kwargs.get("llm_model")
+        self._llm_prompt = kwargs.get("llm_prompt")
         self._exiftool_path = kwargs.get("exiftool_path")
         self._style_map = kwargs.get("style_map")
@@ -559,6 +561,9 @@ class MarkItDown:
             if "llm_model" not in _kwargs and self._llm_model is not None:
                 _kwargs["llm_model"] = self._llm_model
 
+            if "llm_prompt" not in _kwargs and self._llm_prompt is not None:
+                _kwargs["llm_prompt"] = self._llm_prompt
+
             if "style_map" not in _kwargs and self._style_map is not None:
                 _kwargs["style_map"] = self._style_map
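This hunk follows the forwarding pattern already used for `llm_client`, `llm_model`, and `style_map`: options set on the `MarkItDown` instance act as defaults, copied into the per-conversion kwargs only when the caller did not supply them explicitly and only when they are actually set. A minimal standalone sketch of that merge rule (the function and variable names are illustrative, not from the codebase):

def merge_instance_defaults(call_kwargs: dict, defaults: dict) -> dict:
    # Explicit per-call kwargs win; unset (None) instance defaults are skipped.
    for key, value in defaults.items():
        if key not in call_kwargs and value is not None:
            call_kwargs[key] = value
    return call_kwargs

merged = merge_instance_defaults(
    {"llm_model": "gpt-4o-mini"},  # explicit per-call value is preserved
    {"llm_model": "gpt-4o", "llm_prompt": "Describe the image.", "style_map": None},
)
assert merged == {"llm_model": "gpt-4o-mini", "llm_prompt": "Describe the image."}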
@@ -4,6 +4,7 @@ import os
 import re
 import shutil
 import pytest
+from unittest.mock import MagicMock
 
 from markitdown._uri_utils import parse_data_uri, file_uri_to_path
 
@@ -370,6 +371,50 @@ def test_markitdown_exiftool() -> None:
     assert target in result.text_content
 
 
+def test_markitdown_llm_parameters() -> None:
+    """Test that LLM parameters are correctly passed to the client."""
+    mock_client = MagicMock()
+    mock_response = MagicMock()
+    mock_response.choices = [
+        MagicMock(
+            message=MagicMock(
+                content="Test caption with red circle and blue square 5bda1dd6"
+            )
+        )
+    ]
+    mock_client.chat.completions.create.return_value = mock_response
+
+    test_prompt = "You are a professional test prompt."
+    markitdown = MarkItDown(
+        llm_client=mock_client, llm_model="gpt-4o", llm_prompt=test_prompt
+    )
+
+    # Test image file
+    markitdown.convert(os.path.join(TEST_FILES_DIR, "test_llm.jpg"))
+
+    # Verify the prompt was passed to the OpenAI API
+    assert mock_client.chat.completions.create.called
+    call_args = mock_client.chat.completions.create.call_args
+    messages = call_args[1]["messages"]
+    assert len(messages) == 1
+    assert messages[0]["content"][0]["text"] == test_prompt
+
+    # Reset the mock for the next test
+    mock_client.chat.completions.create.reset_mock()
+
+    # TODO: may need only one test once the duplicate LLM caption method has been removed:
+    # https://github.com/microsoft/markitdown/pull/1254
+    # Test PPTX file
+    markitdown.convert(os.path.join(TEST_FILES_DIR, "test.pptx"))
+
+    # Verify the prompt was passed to the OpenAI API for PPTX images too
+    assert mock_client.chat.completions.create.called
+    call_args = mock_client.chat.completions.create.call_args
+    messages = call_args[1]["messages"]
+    assert len(messages) == 1
+    assert messages[0]["content"][0]["text"] == test_prompt
+
+
 @pytest.mark.skipif(
     skip_llm,
     reason="do not run llm tests without a key",
@@ -408,6 +453,7 @@ if __name__ == "__main__":
         test_speech_transcription,
         test_exceptions,
         test_markitdown_exiftool,
+        test_markitdown_llm_parameters,
         test_markitdown_llm,
     ]:
         print(f"Running {test.__name__}...", end="")
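A note on how the new test asserts the prompt: `unittest.mock` records each call on a mock, and `call_args` behaves like an `(args, kwargs)` pair, so `call_args[1]["messages"]` is the `messages` keyword argument of the most recent `chat.completions.create` call. A self-contained sketch of that introspection pattern (the client here is a plain mock, not a real OpenAI client):

from unittest.mock import MagicMock

client = MagicMock()
client.chat.completions.create(model="gpt-4o", messages=[{"role": "user", "content": "hi"}])

# call_args is (args, kwargs); index 1 selects the kwargs dict.
args, kwargs = client.chat.completions.create.call_args
assert kwargs["messages"][0]["content"] == "hi"

# Equivalent and more readable on Python 3.8+:
assert client.chat.completions.create.call_args.kwargs["messages"][0]["content"] == "hi"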