From e928b43afb0d57e1e71e98a9a2b4832aae46a0cc Mon Sep 17 00:00:00 2001 From: afourney Date: Mon, 24 Mar 2025 21:43:04 -0700 Subject: [PATCH] convert_url renamed to convert_uri, and now handles data and file URIs (#1153) --- .../markitdown/src/markitdown/_markitdown.py | 90 +++++++++++++++---- .../markitdown/src/markitdown/_uri_utils.py | 52 +++++++++++ packages/markitdown/tests/test_module_misc.py | 77 ++++++++++++++++ .../markitdown/tests/test_module_vectors.py | 54 +++++++++-- 4 files changed, 251 insertions(+), 22 deletions(-) create mode 100644 packages/markitdown/src/markitdown/_uri_utils.py diff --git a/packages/markitdown/src/markitdown/_markitdown.py b/packages/markitdown/src/markitdown/_markitdown.py index a8f7c9e..8f58db4 100644 --- a/packages/markitdown/src/markitdown/_markitdown.py +++ b/packages/markitdown/src/markitdown/_markitdown.py @@ -20,6 +20,7 @@ import charset_normalizer import codecs from ._stream_info import StreamInfo +from ._uri_utils import parse_data_uri, file_uri_to_path from .converters import ( PlainTextConverter, @@ -242,9 +243,10 @@ class MarkItDown: # Local path or url if isinstance(source, str): if ( - source.startswith("http://") - or source.startswith("https://") - or source.startswith("file://") + source.startswith("http:") + or source.startswith("https:") + or source.startswith("file:") + or source.startswith("data:") ): # Rename the url argument to mock_url # (Deprecated -- use stream_info) @@ -253,7 +255,7 @@ class MarkItDown: _kwargs["mock_url"] = _kwargs["url"] del _kwargs["url"] - return self.convert_url(source, stream_info=stream_info, **_kwargs) + return self.convert_uri(source, stream_info=stream_info, **_kwargs) else: return self.convert_local(source, stream_info=stream_info, **kwargs) # Path object @@ -363,22 +365,80 @@ class MarkItDown: url: str, *, stream_info: Optional[StreamInfo] = None, + file_extension: Optional[str] = None, + mock_url: Optional[str] = None, + **kwargs: Any, + ) -> DocumentConverterResult: + """Alias for convert_uri()""" + # convert_url will likely be deprecated in the future in favor of convert_uri + return self.convert_uri( + url, + stream_info=stream_info, + file_extension=file_extension, + mock_url=mock_url, + **kwargs, + ) + + def convert_uri( + self, + uri: str, + *, + stream_info: Optional[StreamInfo] = None, file_extension: Optional[str] = None, # Deprecated -- use stream_info mock_url: Optional[ str ] = None, # Mock the request as if it came from a different URL **kwargs: Any, - ) -> DocumentConverterResult: # TODO: fix kwargs type - # Send a HTTP request to the URL - response = self._requests_session.get(url, stream=True) - response.raise_for_status() - return self.convert_response( - response, - stream_info=stream_info, - file_extension=file_extension, - url=mock_url, - **kwargs, - ) + ) -> DocumentConverterResult: + uri = uri.strip() + + # File URIs + if uri.startswith("file:"): + netloc, path = file_uri_to_path(uri) + if netloc and netloc != "localhost": + raise ValueError( + f"Unsupported file URI: {uri}. Netloc must be empty or localhost." + ) + return self.convert_local( + path, + stream_info=stream_info, + file_extension=file_extension, + url=mock_url, + **kwargs, + ) + # Data URIs + elif uri.startswith("data:"): + mimetype, attributes, data = parse_data_uri(uri) + + base_guess = StreamInfo( + mimetype=mimetype, + charset=attributes.get("charset"), + ) + if stream_info is not None: + base_guess = base_guess.copy_and_update(stream_info) + + return self.convert_stream( + io.BytesIO(data), + stream_info=base_guess, + file_extension=file_extension, + url=mock_url, + **kwargs, + ) + # HTTP/HTTPS URIs + elif uri.startswith("http:") or uri.startswith("https:"): + response = self._requests_session.get(uri, stream=True) + response.raise_for_status() + return self.convert_response( + response, + stream_info=stream_info, + file_extension=file_extension, + url=mock_url, + **kwargs, + ) + else: + raise ValueError( + f"Unsupported URI scheme: {uri.split(':')[0]}. Supported schemes are: file:, data:, http:, https:" + ) def convert_response( self, diff --git a/packages/markitdown/src/markitdown/_uri_utils.py b/packages/markitdown/src/markitdown/_uri_utils.py new file mode 100644 index 0000000..603da63 --- /dev/null +++ b/packages/markitdown/src/markitdown/_uri_utils.py @@ -0,0 +1,52 @@ +import base64 +import os +from typing import Tuple, Dict +from urllib.request import url2pathname +from urllib.parse import urlparse, unquote_to_bytes + + +def file_uri_to_path(file_uri: str) -> Tuple[str | None, str]: + """Convert a file URI to a local file path""" + parsed = urlparse(file_uri) + if parsed.scheme != "file": + raise ValueError(f"Not a file URL: {file_uri}") + + netloc = parsed.netloc if parsed.netloc else None + path = os.path.abspath(url2pathname(parsed.path)) + return netloc, path + + +def parse_data_uri(uri: str) -> Tuple[str | None, Dict[str, str], bytes]: + if not uri.startswith("data:"): + raise ValueError("Not a data URI") + + header, _, data = uri.partition(",") + if not _: + raise ValueError("Malformed data URI, missing ',' separator") + + meta = header[5:] # Strip 'data:' + parts = meta.split(";") + + is_base64 = False + # Ends with base64? + if parts[-1] == "base64": + parts.pop() + is_base64 = True + + mime_type = None # Normally this would default to text/plain but we won't assume + if len(parts) and len(parts[0]) > 0: + # First part is the mime type + mime_type = parts.pop(0) + + attributes: Dict[str, str] = {} + for part in parts: + # Handle key=value pairs in the middle + if "=" in part: + key, value = part.split("=", 1) + attributes[key] = value + elif len(part) > 0: + attributes[part] = "" + + content = base64.b64decode(data) if is_base64 else unquote_to_bytes(data) + + return mime_type, attributes, content diff --git a/packages/markitdown/tests/test_module_misc.py b/packages/markitdown/tests/test_module_misc.py index 4079107..33e2c44 100644 --- a/packages/markitdown/tests/test_module_misc.py +++ b/packages/markitdown/tests/test_module_misc.py @@ -5,6 +5,8 @@ import shutil import openai import pytest +from markitdown._uri_utils import parse_data_uri, file_uri_to_path + from markitdown import ( MarkItDown, UnsupportedFormatException, @@ -176,6 +178,79 @@ def test_stream_info_operations() -> None: assert updated_stream_info.url == "url.1" +def test_data_uris() -> None: + # Test basic parsing of data URIs + data_uri = "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ==" + mime_type, attributes, data = parse_data_uri(data_uri) + assert mime_type == "text/plain" + assert len(attributes) == 0 + assert data == b"Hello, World!" + + data_uri = "data:base64,SGVsbG8sIFdvcmxkIQ==" + mime_type, attributes, data = parse_data_uri(data_uri) + assert mime_type is None + assert len(attributes) == 0 + assert data == b"Hello, World!" + + data_uri = "data:text/plain;charset=utf-8;base64,SGVsbG8sIFdvcmxkIQ==" + mime_type, attributes, data = parse_data_uri(data_uri) + assert mime_type == "text/plain" + assert len(attributes) == 1 + assert attributes["charset"] == "utf-8" + assert data == b"Hello, World!" + + data_uri = "data:,Hello%2C%20World%21" + mime_type, attributes, data = parse_data_uri(data_uri) + assert mime_type is None + assert len(attributes) == 0 + assert data == b"Hello, World!" + + data_uri = "data:text/plain,Hello%2C%20World%21" + mime_type, attributes, data = parse_data_uri(data_uri) + assert mime_type == "text/plain" + assert len(attributes) == 0 + assert data == b"Hello, World!" + + data_uri = "data:text/plain;charset=utf-8,Hello%2C%20World%21" + mime_type, attributes, data = parse_data_uri(data_uri) + assert mime_type == "text/plain" + assert len(attributes) == 1 + assert attributes["charset"] == "utf-8" + assert data == b"Hello, World!" + + +def test_file_uris() -> None: + # Test file URI with an empty host + file_uri = "file:///path/to/file.txt" + netloc, path = file_uri_to_path(file_uri) + assert netloc is None + assert path == "/path/to/file.txt" + + # Test file URI with no host + file_uri = "file:/path/to/file.txt" + netloc, path = file_uri_to_path(file_uri) + assert netloc is None + assert path == "/path/to/file.txt" + + # Test file URI with localhost + file_uri = "file://localhost/path/to/file.txt" + netloc, path = file_uri_to_path(file_uri) + assert netloc == "localhost" + assert path == "/path/to/file.txt" + + # Test file URI with query parameters + file_uri = "file:///path/to/file.txt?param=value" + netloc, path = file_uri_to_path(file_uri) + assert netloc is None + assert path == "/path/to/file.txt" + + # Test file URI with fragment + file_uri = "file:///path/to/file.txt#fragment" + netloc, path = file_uri_to_path(file_uri) + assert netloc is None + assert path == "/path/to/file.txt" + + def test_docx_comments() -> None: markitdown = MarkItDown() @@ -314,6 +389,8 @@ if __name__ == "__main__": """Runs this file's tests from the command line.""" for test in [ test_stream_info_operations, + test_data_uris, + test_file_uris, test_docx_comments, test_input_as_strings, test_markitdown_remote, diff --git a/packages/markitdown/tests/test_module_vectors.py b/packages/markitdown/tests/test_module_vectors.py index 09e4a2b..98fd0c7 100644 --- a/packages/markitdown/tests/test_module_vectors.py +++ b/packages/markitdown/tests/test_module_vectors.py @@ -3,7 +3,9 @@ import os import time import pytest import codecs +import base64 +from pathlib import Path if __name__ == "__main__": from _test_vectors import GENERAL_TEST_VECTORS, DATA_URI_TEST_VECTORS @@ -108,8 +110,8 @@ def test_convert_stream_without_hints(test_vector): reason="do not run tests that query external urls", ) @pytest.mark.parametrize("test_vector", GENERAL_TEST_VECTORS) -def test_convert_url(test_vector): - """Test the conversion of a stream with no stream info.""" +def test_convert_http_uri(test_vector): + """Test the conversion of an HTTP:// or HTTPS:// URI.""" markitdown = MarkItDown() time.sleep(1) # Ensure we don't hit rate limits @@ -124,8 +126,44 @@ def test_convert_url(test_vector): assert string not in result.markdown +@pytest.mark.parametrize("test_vector", GENERAL_TEST_VECTORS) +def test_convert_file_uri(test_vector): + """Test the conversion of a file:// URI.""" + markitdown = MarkItDown() + + result = markitdown.convert( + Path(os.path.join(TEST_FILES_DIR, test_vector.filename)).as_uri(), + url=test_vector.url, + ) + for string in test_vector.must_include: + assert string in result.markdown + for string in test_vector.must_not_include: + assert string not in result.markdown + + +@pytest.mark.parametrize("test_vector", GENERAL_TEST_VECTORS) +def test_convert_data_uri(test_vector): + """Test the conversion of a data URI.""" + markitdown = MarkItDown() + + data = "" + with open(os.path.join(TEST_FILES_DIR, test_vector.filename), "rb") as stream: + data = base64.b64encode(stream.read()).decode("utf-8") + mimetype = test_vector.mimetype + data_uri = f"data:{mimetype};base64,{data}" + + result = markitdown.convert( + data_uri, + url=test_vector.url, + ) + for string in test_vector.must_include: + assert string in result.markdown + for string in test_vector.must_not_include: + assert string not in result.markdown + + @pytest.mark.parametrize("test_vector", DATA_URI_TEST_VECTORS) -def test_convert_with_data_uris(test_vector): +def test_convert_keep_data_uris(test_vector): """Test API functionality when keep_data_uris is enabled""" markitdown = MarkItDown() @@ -143,7 +181,7 @@ def test_convert_with_data_uris(test_vector): @pytest.mark.parametrize("test_vector", DATA_URI_TEST_VECTORS) -def test_convert_stream_with_data_uris(test_vector): +def test_convert_stream_keep_data_uris(test_vector): """Test the conversion of a stream with no stream info.""" markitdown = MarkItDown() @@ -175,7 +213,9 @@ if __name__ == "__main__": test_convert_local, test_convert_stream_with_hints, test_convert_stream_without_hints, - test_convert_url, + test_convert_http_uri, + test_convert_file_uri, + test_convert_data_uri, ]: for test_vector in GENERAL_TEST_VECTORS: print( @@ -186,8 +226,8 @@ if __name__ == "__main__": # Data URI tests for test_function in [ - test_convert_with_data_uris, - test_convert_stream_with_data_uris, + test_convert_keep_data_uris, + test_convert_stream_keep_data_uris, ]: for test_vector in DATA_URI_TEST_VECTORS: print(