From 03a7843a0ae0b073563334f7b52019522d9393f6 Mon Sep 17 00:00:00 2001 From: Adam Fourney Date: Tue, 17 Dec 2024 13:22:48 -0800 Subject: [PATCH] Added deprecation warnings for mlm_* arguments. --- src/markitdown/_markitdown.py | 38 +++++++++++++++++++++++++++++++++-- tests/test_markitdown.py | 37 ++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/src/markitdown/_markitdown.py b/src/markitdown/_markitdown.py index a7fb28a..9400274 100644 --- a/src/markitdown/_markitdown.py +++ b/src/markitdown/_markitdown.py @@ -15,7 +15,7 @@ import traceback import zipfile from typing import Any, Dict, List, Optional, Union from urllib.parse import parse_qs, quote, unquote, urlparse, urlunparse -from warnings import catch_warnings +from warnings import warn, resetwarnings, catch_warnings import mammoth import markdownify @@ -44,6 +44,8 @@ try: IS_AUDIO_TRANSCRIPTION_CAPABLE = True except ModuleNotFoundError: pass +finally: + resetwarnings() # Optional YouTube transcription support try: @@ -1010,14 +1012,46 @@ class MarkItDown: self, requests_session: Optional[requests.Session] = None, llm_client: Optional[Any] = None, - llm_model: Optional[Any] = None, + llm_model: Optional[str] = None, style_map: Optional[str] = None, + # Deprecated + mlm_client: Optional[Any] = None, + mlm_model: Optional[str] = None, ): if requests_session is None: self._requests_session = requests.Session() else: self._requests_session = requests_session + # Handle deprecation notices + ############################# + if mlm_client is not None: + if llm_client is None: + warn( + "'mlm_client' is deprecated, and was renamed 'llm_client'.", + DeprecationWarning, + ) + llm_client = mlm_client + mlm_client = None + else: + raise ValueError( + "'mlm_client' is deprecated, and was renamed 'llm_client'. Do not use both at the same time. Just use 'llm_client' instead." + ) + + if mlm_model is not None: + if llm_model is None: + warn( + "'mlm_model' is deprecated, and was renamed 'llm_model'.", + DeprecationWarning, + ) + llm_model = mlm_model + mlm_model = None + else: + raise ValueError( + "'mlm_model' is deprecated, and was renamed 'llm_model'. Do not use both at the same time. Just use 'llm_model' instead." + ) + ############################# + self._llm_client = llm_client self._llm_model = llm_model self._style_map = style_map diff --git a/tests/test_markitdown.py b/tests/test_markitdown.py index 76bd302..f2348a1 100644 --- a/tests/test_markitdown.py +++ b/tests/test_markitdown.py @@ -6,6 +6,8 @@ import shutil import pytest import requests +from warnings import catch_warnings, resetwarnings + from markitdown import MarkItDown skip_remote = ( @@ -229,8 +231,43 @@ def test_markitdown_exiftool() -> None: assert target in result.text_content +def test_markitdown_deprecation() -> None: + try: + with catch_warnings(record=True) as w: + test_client = object() + markitdown = MarkItDown(mlm_client=test_client) + assert len(w) == 1 + assert w[0].category is DeprecationWarning + assert markitdown._llm_client == test_client + finally: + resetwarnings() + + try: + with catch_warnings(record=True) as w: + markitdown = MarkItDown(mlm_model="gpt-4o") + assert len(w) == 1 + assert w[0].category is DeprecationWarning + assert markitdown._llm_model == "gpt-4o" + finally: + resetwarnings() + + try: + test_client = object() + markitdown = MarkItDown(mlm_client=test_client, llm_client=test_client) + assert False + except ValueError: + pass + + try: + markitdown = MarkItDown(mlm_model="gpt-4o", llm_model="gpt-4o") + assert False + except ValueError: + pass + + if __name__ == "__main__": """Runs this file's tests from the command line.""" test_markitdown_remote() test_markitdown_local() test_markitdown_exiftool() + test_markitdown_deprecation()