Minimize guesses when guesses are compatible. (#1114)
* Minimize guesses when guesses are compatible.
This commit is contained in:
@@ -15,8 +15,11 @@ from pathlib import Path
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
import requests
|
import requests
|
||||||
|
import magika
|
||||||
|
import charset_normalizer
|
||||||
|
import codecs
|
||||||
|
|
||||||
from ._stream_info import StreamInfo, _guess_stream_info_from_stream
|
from ._stream_info import StreamInfo
|
||||||
|
|
||||||
from .converters import (
|
from .converters import (
|
||||||
PlainTextConverter,
|
PlainTextConverter,
|
||||||
@@ -107,6 +110,8 @@ class MarkItDown:
|
|||||||
else:
|
else:
|
||||||
self._requests_session = requests_session
|
self._requests_session = requests_session
|
||||||
|
|
||||||
|
self._magika = magika.Magika()
|
||||||
|
|
||||||
# TODO - remove these (see enable_builtins)
|
# TODO - remove these (see enable_builtins)
|
||||||
self._llm_client: Any = None
|
self._llm_client: Any = None
|
||||||
self._llm_model: Union[str | None] = None
|
self._llm_model: Union[str | None] = None
|
||||||
@@ -273,33 +278,28 @@ class MarkItDown:
|
|||||||
path = str(path)
|
path = str(path)
|
||||||
|
|
||||||
# Build a base StreamInfo object from which to start guesses
|
# Build a base StreamInfo object from which to start guesses
|
||||||
base_stream_info = StreamInfo(
|
base_guess = StreamInfo(
|
||||||
local_path=path,
|
local_path=path,
|
||||||
extension=os.path.splitext(path)[1],
|
extension=os.path.splitext(path)[1],
|
||||||
filename=os.path.basename(path),
|
filename=os.path.basename(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extend the base_stream_info with any additional info from the arguments
|
# Extend the base_guess with any additional info from the arguments
|
||||||
if stream_info is not None:
|
if stream_info is not None:
|
||||||
base_stream_info = base_stream_info.copy_and_update(stream_info)
|
base_guess = base_guess.copy_and_update(stream_info)
|
||||||
|
|
||||||
if file_extension is not None:
|
if file_extension is not None:
|
||||||
# Deprecated -- use stream_info
|
# Deprecated -- use stream_info
|
||||||
base_stream_info = base_stream_info.copy_and_update(
|
base_guess = base_guess.copy_and_update(extension=file_extension)
|
||||||
extension=file_extension
|
|
||||||
)
|
|
||||||
|
|
||||||
if url is not None:
|
if url is not None:
|
||||||
# Deprecated -- use stream_info
|
# Deprecated -- use stream_info
|
||||||
base_stream_info = base_stream_info.copy_and_update(url=url)
|
base_guess = base_guess.copy_and_update(url=url)
|
||||||
|
|
||||||
with open(path, "rb") as fh:
|
with open(path, "rb") as fh:
|
||||||
# Prepare a list of configurations to try, starting with the base_stream_info
|
guesses = self._get_stream_info_guesses(
|
||||||
guesses: List[StreamInfo] = [base_stream_info]
|
file_stream=fh, base_guess=base_guess
|
||||||
for guess in _guess_stream_info_from_stream(
|
)
|
||||||
file_stream=fh, filename_hint=path
|
|
||||||
):
|
|
||||||
guesses.append(base_stream_info.copy_and_update(guess))
|
|
||||||
return self._convert(file_stream=fh, stream_info_guesses=guesses, **kwargs)
|
return self._convert(file_stream=fh, stream_info_guesses=guesses, **kwargs)
|
||||||
|
|
||||||
def convert_stream(
|
def convert_stream(
|
||||||
@@ -332,21 +332,6 @@ class MarkItDown:
|
|||||||
assert base_guess is not None # for mypy
|
assert base_guess is not None # for mypy
|
||||||
base_guess = base_guess.copy_and_update(url=url)
|
base_guess = base_guess.copy_and_update(url=url)
|
||||||
|
|
||||||
# Append the base guess, if it's non-trivial
|
|
||||||
if base_guess is not None:
|
|
||||||
if base_guess.mimetype is not None or base_guess.extension is not None:
|
|
||||||
guesses.append(base_guess)
|
|
||||||
else:
|
|
||||||
# Create a base guess with no information
|
|
||||||
base_guess = StreamInfo()
|
|
||||||
|
|
||||||
# Create a placeholder filename to help with guessing
|
|
||||||
placeholder_filename = None
|
|
||||||
if base_guess.filename is not None:
|
|
||||||
placeholder_filename = base_guess.filename
|
|
||||||
elif base_guess.extension is not None:
|
|
||||||
placeholder_filename = "placeholder" + base_guess.extension
|
|
||||||
|
|
||||||
# Check if we have a seekable stream. If not, load the entire stream into memory.
|
# Check if we have a seekable stream. If not, load the entire stream into memory.
|
||||||
if not stream.seekable():
|
if not stream.seekable():
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
@@ -359,12 +344,9 @@ class MarkItDown:
|
|||||||
stream = buffer
|
stream = buffer
|
||||||
|
|
||||||
# Add guesses based on stream content
|
# Add guesses based on stream content
|
||||||
for guess in _guess_stream_info_from_stream(
|
guesses = self._get_stream_info_guesses(
|
||||||
file_stream=stream, filename_hint=placeholder_filename
|
file_stream=stream, base_guess=base_guess or StreamInfo()
|
||||||
):
|
)
|
||||||
guesses.append(base_guess.copy_and_update(guess))
|
|
||||||
|
|
||||||
# Perform the conversion
|
|
||||||
return self._convert(file_stream=stream, stream_info_guesses=guesses, **kwargs)
|
return self._convert(file_stream=stream, stream_info_guesses=guesses, **kwargs)
|
||||||
|
|
||||||
def convert_url(
|
def convert_url(
|
||||||
@@ -435,31 +417,16 @@ class MarkItDown:
|
|||||||
# Deprecated -- use stream_info
|
# Deprecated -- use stream_info
|
||||||
base_guess = base_guess.copy_and_update(url=url)
|
base_guess = base_guess.copy_and_update(url=url)
|
||||||
|
|
||||||
# Add the guess if its non-trivial
|
|
||||||
guesses: List[StreamInfo] = []
|
|
||||||
if base_guess.mimetype is not None or base_guess.extension is not None:
|
|
||||||
guesses.append(base_guess)
|
|
||||||
|
|
||||||
# Read into BytesIO
|
# Read into BytesIO
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
for chunk in response.iter_content(chunk_size=512):
|
for chunk in response.iter_content(chunk_size=512):
|
||||||
buffer.write(chunk)
|
buffer.write(chunk)
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
|
|
||||||
# Create a placeholder filename to help with guessing
|
|
||||||
placeholder_filename = None
|
|
||||||
if base_guess.filename is not None:
|
|
||||||
placeholder_filename = base_guess.filename
|
|
||||||
elif base_guess.extension is not None:
|
|
||||||
placeholder_filename = "placeholder" + base_guess.extension
|
|
||||||
|
|
||||||
# Add guesses based on stream content
|
|
||||||
for guess in _guess_stream_info_from_stream(
|
|
||||||
file_stream=buffer, filename_hint=placeholder_filename
|
|
||||||
):
|
|
||||||
guesses.append(base_guess.copy_and_update(guess))
|
|
||||||
|
|
||||||
# Convert
|
# Convert
|
||||||
|
guesses = self._get_stream_info_guesses(
|
||||||
|
file_stream=buffer, base_guess=base_guess
|
||||||
|
)
|
||||||
return self._convert(file_stream=buffer, stream_info_guesses=guesses, **kwargs)
|
return self._convert(file_stream=buffer, stream_info_guesses=guesses, **kwargs)
|
||||||
|
|
||||||
def _convert(
|
def _convert(
|
||||||
@@ -593,3 +560,94 @@ class MarkItDown:
|
|||||||
self._converters.insert(
|
self._converters.insert(
|
||||||
0, ConverterRegistration(converter=converter, priority=priority)
|
0, ConverterRegistration(converter=converter, priority=priority)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_stream_info_guesses(
|
||||||
|
self, file_stream: BinaryIO, base_guess: StreamInfo
|
||||||
|
) -> List[StreamInfo]:
|
||||||
|
"""
|
||||||
|
Given a base guess, attempt to guess or expand on the stream info using the stream content (via magika).
|
||||||
|
"""
|
||||||
|
guesses: List[StreamInfo] = []
|
||||||
|
|
||||||
|
# Call magika to guess from the stream
|
||||||
|
cur_pos = file_stream.tell()
|
||||||
|
try:
|
||||||
|
stream_bytes = file_stream.read()
|
||||||
|
|
||||||
|
result = self._magika.identify_bytes(stream_bytes)
|
||||||
|
if result.status == "ok" and result.prediction.output.label != "unknown":
|
||||||
|
# If it's text, also guess the charset
|
||||||
|
charset = None
|
||||||
|
if result.prediction.output.is_text:
|
||||||
|
charset_result = charset_normalizer.from_bytes(stream_bytes).best()
|
||||||
|
if charset_result is not None:
|
||||||
|
charset = self._normalize_charset(charset_result.encoding)
|
||||||
|
|
||||||
|
# Normalize the first extension listed
|
||||||
|
guessed_extension = None
|
||||||
|
if len(result.prediction.output.extensions) > 0:
|
||||||
|
guessed_extension = "." + result.prediction.output.extensions[0]
|
||||||
|
|
||||||
|
# Determine if the guess is compatible with the base guess
|
||||||
|
compatible = True
|
||||||
|
if (
|
||||||
|
base_guess.mimetype is not None
|
||||||
|
and base_guess.mimetype != result.prediction.output.mime_type
|
||||||
|
):
|
||||||
|
compatible = False
|
||||||
|
|
||||||
|
if (
|
||||||
|
base_guess.extension is not None
|
||||||
|
and base_guess.extension.lstrip(".")
|
||||||
|
not in result.prediction.output.extensions
|
||||||
|
):
|
||||||
|
compatible = False
|
||||||
|
|
||||||
|
if (
|
||||||
|
base_guess.charset is not None
|
||||||
|
and self._normalize_charset(base_guess.charset) != charset
|
||||||
|
):
|
||||||
|
compatible = False
|
||||||
|
|
||||||
|
if compatible:
|
||||||
|
# Add the compatible base guess
|
||||||
|
guesses.append(
|
||||||
|
StreamInfo(
|
||||||
|
mimetype=base_guess.mimetype
|
||||||
|
or result.prediction.output.mime_type,
|
||||||
|
extension=base_guess.extension or guessed_extension,
|
||||||
|
charset=base_guess.charset or charset,
|
||||||
|
filename=base_guess.filename,
|
||||||
|
local_path=base_guess.local_path,
|
||||||
|
url=base_guess.url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# The magika guess was incompatible with the base guess, so add both guesses
|
||||||
|
guesses.append(base_guess)
|
||||||
|
guesses.append(
|
||||||
|
StreamInfo(
|
||||||
|
mimetype=result.prediction.output.mime_type,
|
||||||
|
extension=guessed_extension,
|
||||||
|
charset=charset,
|
||||||
|
filename=base_guess.filename,
|
||||||
|
local_path=base_guess.local_path,
|
||||||
|
url=base_guess.url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# There were no other guesses, so just add the base guess
|
||||||
|
guesses.append(base_guess)
|
||||||
|
finally:
|
||||||
|
file_stream.seek(cur_pos)
|
||||||
|
|
||||||
|
return guesses
|
||||||
|
|
||||||
|
def _normalize_charset(self, charset: str) -> str:
|
||||||
|
"""
|
||||||
|
Normalize a charset string to a canonical form.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return codecs.lookup(charset).name
|
||||||
|
except LookupError:
|
||||||
|
return charset
|
||||||
|
|||||||
@@ -1,10 +1,5 @@
|
|||||||
import mimetypes
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from typing import Optional, BinaryIO, List, TypeVar, Type
|
from typing import Optional
|
||||||
from magika import Magika
|
|
||||||
|
|
||||||
magika = Magika()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True, frozen=True)
|
@dataclass(kw_only=True, frozen=True)
|
||||||
@@ -35,58 +30,3 @@ class StreamInfo:
|
|||||||
new_info.update(kwargs)
|
new_info.update(kwargs)
|
||||||
|
|
||||||
return StreamInfo(**new_info)
|
return StreamInfo(**new_info)
|
||||||
|
|
||||||
|
|
||||||
# Behavior subject to change.
|
|
||||||
# Do not rely on this outside of this module.
|
|
||||||
def _guess_stream_info_from_stream(
|
|
||||||
file_stream: BinaryIO,
|
|
||||||
*,
|
|
||||||
filename_hint: Optional[str] = None,
|
|
||||||
) -> List[StreamInfo]:
|
|
||||||
"""
|
|
||||||
Guess StreamInfo properties (mostly mimetype and extension) from a stream.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- stream: The stream to guess the StreamInfo from.
|
|
||||||
- filename_hint [Optional]: A filename hint to help with the guessing (may be a placeholder, and not actually be the file name)
|
|
||||||
|
|
||||||
Returns a list of StreamInfo objects in order of confidence.
|
|
||||||
"""
|
|
||||||
guesses: List[StreamInfo] = []
|
|
||||||
|
|
||||||
# Call magika to guess from the stream
|
|
||||||
cur_pos = file_stream.tell()
|
|
||||||
try:
|
|
||||||
result = magika.identify_bytes(file_stream.read())
|
|
||||||
if result.status == "ok" and result.prediction.output.label != "unknown":
|
|
||||||
extension = None
|
|
||||||
if len(result.prediction.output.extensions) > 0:
|
|
||||||
extension = result.prediction.output.extensions[0]
|
|
||||||
if extension and not extension.startswith("."):
|
|
||||||
extension = "." + extension
|
|
||||||
guesses.append(
|
|
||||||
StreamInfo(
|
|
||||||
mimetype=result.prediction.output.mime_type,
|
|
||||||
extension=extension,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
file_stream.seek(cur_pos)
|
|
||||||
|
|
||||||
# Add a guess purely based on the filename hint
|
|
||||||
if filename_hint:
|
|
||||||
try:
|
|
||||||
# Requires Python 3.13+
|
|
||||||
mimetype, _ = mimetypes.guess_file_type(filename_hint) # type: ignore
|
|
||||||
except AttributeError:
|
|
||||||
mimetype, _ = mimetypes.guess_type(filename_hint)
|
|
||||||
|
|
||||||
if mimetype:
|
|
||||||
guesses.append(
|
|
||||||
StreamInfo(
|
|
||||||
mimetype=mimetype, extension=os.path.splitext(filename_hint)[1]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return guesses
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from markitdown import (
|
|||||||
FileConversionException,
|
FileConversionException,
|
||||||
StreamInfo,
|
StreamInfo,
|
||||||
)
|
)
|
||||||
from markitdown._stream_info import _guess_stream_info_from_stream
|
|
||||||
|
|
||||||
skip_remote = (
|
skip_remote = (
|
||||||
True if os.environ.get("GITHUB_ACTIONS") else False
|
True if os.environ.get("GITHUB_ACTIONS") else False
|
||||||
@@ -265,10 +264,16 @@ def test_stream_info_guesses() -> None:
|
|||||||
(os.path.join(TEST_FILES_DIR, "test.xls"), "application/vnd.ms-excel"),
|
(os.path.join(TEST_FILES_DIR, "test.xls"), "application/vnd.ms-excel"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
markitdown = MarkItDown()
|
||||||
for file_path, expected_mimetype in test_tuples:
|
for file_path, expected_mimetype in test_tuples:
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
guesses = _guess_stream_info_from_stream(
|
guesses = markitdown._get_stream_info_guesses(
|
||||||
f, filename_hint=os.path.basename(file_path)
|
f,
|
||||||
|
StreamInfo(
|
||||||
|
filename=os.path.basename(file_path),
|
||||||
|
local_path=file_path,
|
||||||
|
extension=os.path.splitext(file_path)[1],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert len(guesses) > 0
|
assert len(guesses) > 0
|
||||||
assert guesses[0].mimetype == expected_mimetype
|
assert guesses[0].mimetype == expected_mimetype
|
||||||
|
|||||||
Reference in New Issue
Block a user