Minimize guesses when guesses are compatible. (#1114)
* Minimize guesses when guesses are compatible.
This commit is contained in:
@@ -15,8 +15,11 @@ from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from warnings import warn
|
||||
import requests
|
||||
import magika
|
||||
import charset_normalizer
|
||||
import codecs
|
||||
|
||||
from ._stream_info import StreamInfo, _guess_stream_info_from_stream
|
||||
from ._stream_info import StreamInfo
|
||||
|
||||
from .converters import (
|
||||
PlainTextConverter,
|
||||
@@ -107,6 +110,8 @@ class MarkItDown:
|
||||
else:
|
||||
self._requests_session = requests_session
|
||||
|
||||
self._magika = magika.Magika()
|
||||
|
||||
# TODO - remove these (see enable_builtins)
|
||||
self._llm_client: Any = None
|
||||
self._llm_model: Union[str | None] = None
|
||||
@@ -273,33 +278,28 @@ class MarkItDown:
|
||||
path = str(path)
|
||||
|
||||
# Build a base StreamInfo object from which to start guesses
|
||||
base_stream_info = StreamInfo(
|
||||
base_guess = StreamInfo(
|
||||
local_path=path,
|
||||
extension=os.path.splitext(path)[1],
|
||||
filename=os.path.basename(path),
|
||||
)
|
||||
|
||||
# Extend the base_stream_info with any additional info from the arguments
|
||||
# Extend the base_guess with any additional info from the arguments
|
||||
if stream_info is not None:
|
||||
base_stream_info = base_stream_info.copy_and_update(stream_info)
|
||||
base_guess = base_guess.copy_and_update(stream_info)
|
||||
|
||||
if file_extension is not None:
|
||||
# Deprecated -- use stream_info
|
||||
base_stream_info = base_stream_info.copy_and_update(
|
||||
extension=file_extension
|
||||
)
|
||||
base_guess = base_guess.copy_and_update(extension=file_extension)
|
||||
|
||||
if url is not None:
|
||||
# Deprecated -- use stream_info
|
||||
base_stream_info = base_stream_info.copy_and_update(url=url)
|
||||
base_guess = base_guess.copy_and_update(url=url)
|
||||
|
||||
with open(path, "rb") as fh:
|
||||
# Prepare a list of configurations to try, starting with the base_stream_info
|
||||
guesses: List[StreamInfo] = [base_stream_info]
|
||||
for guess in _guess_stream_info_from_stream(
|
||||
file_stream=fh, filename_hint=path
|
||||
):
|
||||
guesses.append(base_stream_info.copy_and_update(guess))
|
||||
guesses = self._get_stream_info_guesses(
|
||||
file_stream=fh, base_guess=base_guess
|
||||
)
|
||||
return self._convert(file_stream=fh, stream_info_guesses=guesses, **kwargs)
|
||||
|
||||
def convert_stream(
|
||||
@@ -332,21 +332,6 @@ class MarkItDown:
|
||||
assert base_guess is not None # for mypy
|
||||
base_guess = base_guess.copy_and_update(url=url)
|
||||
|
||||
# Append the base guess, if it's non-trivial
|
||||
if base_guess is not None:
|
||||
if base_guess.mimetype is not None or base_guess.extension is not None:
|
||||
guesses.append(base_guess)
|
||||
else:
|
||||
# Create a base guess with no information
|
||||
base_guess = StreamInfo()
|
||||
|
||||
# Create a placeholder filename to help with guessing
|
||||
placeholder_filename = None
|
||||
if base_guess.filename is not None:
|
||||
placeholder_filename = base_guess.filename
|
||||
elif base_guess.extension is not None:
|
||||
placeholder_filename = "placeholder" + base_guess.extension
|
||||
|
||||
# Check if we have a seekable stream. If not, load the entire stream into memory.
|
||||
if not stream.seekable():
|
||||
buffer = io.BytesIO()
|
||||
@@ -359,12 +344,9 @@ class MarkItDown:
|
||||
stream = buffer
|
||||
|
||||
# Add guesses based on stream content
|
||||
for guess in _guess_stream_info_from_stream(
|
||||
file_stream=stream, filename_hint=placeholder_filename
|
||||
):
|
||||
guesses.append(base_guess.copy_and_update(guess))
|
||||
|
||||
# Perform the conversion
|
||||
guesses = self._get_stream_info_guesses(
|
||||
file_stream=stream, base_guess=base_guess or StreamInfo()
|
||||
)
|
||||
return self._convert(file_stream=stream, stream_info_guesses=guesses, **kwargs)
|
||||
|
||||
def convert_url(
|
||||
@@ -435,31 +417,16 @@ class MarkItDown:
|
||||
# Deprecated -- use stream_info
|
||||
base_guess = base_guess.copy_and_update(url=url)
|
||||
|
||||
# Add the guess if its non-trivial
|
||||
guesses: List[StreamInfo] = []
|
||||
if base_guess.mimetype is not None or base_guess.extension is not None:
|
||||
guesses.append(base_guess)
|
||||
|
||||
# Read into BytesIO
|
||||
buffer = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=512):
|
||||
buffer.write(chunk)
|
||||
buffer.seek(0)
|
||||
|
||||
# Create a placeholder filename to help with guessing
|
||||
placeholder_filename = None
|
||||
if base_guess.filename is not None:
|
||||
placeholder_filename = base_guess.filename
|
||||
elif base_guess.extension is not None:
|
||||
placeholder_filename = "placeholder" + base_guess.extension
|
||||
|
||||
# Add guesses based on stream content
|
||||
for guess in _guess_stream_info_from_stream(
|
||||
file_stream=buffer, filename_hint=placeholder_filename
|
||||
):
|
||||
guesses.append(base_guess.copy_and_update(guess))
|
||||
|
||||
# Convert
|
||||
guesses = self._get_stream_info_guesses(
|
||||
file_stream=buffer, base_guess=base_guess
|
||||
)
|
||||
return self._convert(file_stream=buffer, stream_info_guesses=guesses, **kwargs)
|
||||
|
||||
def _convert(
|
||||
@@ -593,3 +560,94 @@ class MarkItDown:
|
||||
self._converters.insert(
|
||||
0, ConverterRegistration(converter=converter, priority=priority)
|
||||
)
|
||||
|
||||
def _get_stream_info_guesses(
|
||||
self, file_stream: BinaryIO, base_guess: StreamInfo
|
||||
) -> List[StreamInfo]:
|
||||
"""
|
||||
Given a base guess, attempt to guess or expand on the stream info using the stream content (via magika).
|
||||
"""
|
||||
guesses: List[StreamInfo] = []
|
||||
|
||||
# Call magika to guess from the stream
|
||||
cur_pos = file_stream.tell()
|
||||
try:
|
||||
stream_bytes = file_stream.read()
|
||||
|
||||
result = self._magika.identify_bytes(stream_bytes)
|
||||
if result.status == "ok" and result.prediction.output.label != "unknown":
|
||||
# If it's text, also guess the charset
|
||||
charset = None
|
||||
if result.prediction.output.is_text:
|
||||
charset_result = charset_normalizer.from_bytes(stream_bytes).best()
|
||||
if charset_result is not None:
|
||||
charset = self._normalize_charset(charset_result.encoding)
|
||||
|
||||
# Normalize the first extension listed
|
||||
guessed_extension = None
|
||||
if len(result.prediction.output.extensions) > 0:
|
||||
guessed_extension = "." + result.prediction.output.extensions[0]
|
||||
|
||||
# Determine if the guess is compatible with the base guess
|
||||
compatible = True
|
||||
if (
|
||||
base_guess.mimetype is not None
|
||||
and base_guess.mimetype != result.prediction.output.mime_type
|
||||
):
|
||||
compatible = False
|
||||
|
||||
if (
|
||||
base_guess.extension is not None
|
||||
and base_guess.extension.lstrip(".")
|
||||
not in result.prediction.output.extensions
|
||||
):
|
||||
compatible = False
|
||||
|
||||
if (
|
||||
base_guess.charset is not None
|
||||
and self._normalize_charset(base_guess.charset) != charset
|
||||
):
|
||||
compatible = False
|
||||
|
||||
if compatible:
|
||||
# Add the compatible base guess
|
||||
guesses.append(
|
||||
StreamInfo(
|
||||
mimetype=base_guess.mimetype
|
||||
or result.prediction.output.mime_type,
|
||||
extension=base_guess.extension or guessed_extension,
|
||||
charset=base_guess.charset or charset,
|
||||
filename=base_guess.filename,
|
||||
local_path=base_guess.local_path,
|
||||
url=base_guess.url,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# The magika guess was incompatible with the base guess, so add both guesses
|
||||
guesses.append(base_guess)
|
||||
guesses.append(
|
||||
StreamInfo(
|
||||
mimetype=result.prediction.output.mime_type,
|
||||
extension=guessed_extension,
|
||||
charset=charset,
|
||||
filename=base_guess.filename,
|
||||
local_path=base_guess.local_path,
|
||||
url=base_guess.url,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# There were no other guesses, so just add the base guess
|
||||
guesses.append(base_guess)
|
||||
finally:
|
||||
file_stream.seek(cur_pos)
|
||||
|
||||
return guesses
|
||||
|
||||
def _normalize_charset(self, charset: str) -> str:
|
||||
"""
|
||||
Normalize a charset string to a canonical form.
|
||||
"""
|
||||
try:
|
||||
return codecs.lookup(charset).name
|
||||
except LookupError:
|
||||
return charset
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import mimetypes
|
||||
import os
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Optional, BinaryIO, List, TypeVar, Type
|
||||
from magika import Magika
|
||||
|
||||
magika = Magika()
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
@@ -35,58 +30,3 @@ class StreamInfo:
|
||||
new_info.update(kwargs)
|
||||
|
||||
return StreamInfo(**new_info)
|
||||
|
||||
|
||||
# Behavior subject to change.
|
||||
# Do not rely on this outside of this module.
|
||||
def _guess_stream_info_from_stream(
|
||||
file_stream: BinaryIO,
|
||||
*,
|
||||
filename_hint: Optional[str] = None,
|
||||
) -> List[StreamInfo]:
|
||||
"""
|
||||
Guess StreamInfo properties (mostly mimetype and extension) from a stream.
|
||||
|
||||
Args:
|
||||
- stream: The stream to guess the StreamInfo from.
|
||||
- filename_hint [Optional]: A filename hint to help with the guessing (may be a placeholder, and not actually be the file name)
|
||||
|
||||
Returns a list of StreamInfo objects in order of confidence.
|
||||
"""
|
||||
guesses: List[StreamInfo] = []
|
||||
|
||||
# Call magika to guess from the stream
|
||||
cur_pos = file_stream.tell()
|
||||
try:
|
||||
result = magika.identify_bytes(file_stream.read())
|
||||
if result.status == "ok" and result.prediction.output.label != "unknown":
|
||||
extension = None
|
||||
if len(result.prediction.output.extensions) > 0:
|
||||
extension = result.prediction.output.extensions[0]
|
||||
if extension and not extension.startswith("."):
|
||||
extension = "." + extension
|
||||
guesses.append(
|
||||
StreamInfo(
|
||||
mimetype=result.prediction.output.mime_type,
|
||||
extension=extension,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
file_stream.seek(cur_pos)
|
||||
|
||||
# Add a guess purely based on the filename hint
|
||||
if filename_hint:
|
||||
try:
|
||||
# Requires Python 3.13+
|
||||
mimetype, _ = mimetypes.guess_file_type(filename_hint) # type: ignore
|
||||
except AttributeError:
|
||||
mimetype, _ = mimetypes.guess_type(filename_hint)
|
||||
|
||||
if mimetype:
|
||||
guesses.append(
|
||||
StreamInfo(
|
||||
mimetype=mimetype, extension=os.path.splitext(filename_hint)[1]
|
||||
)
|
||||
)
|
||||
|
||||
return guesses
|
||||
|
||||
@@ -13,7 +13,6 @@ from markitdown import (
|
||||
FileConversionException,
|
||||
StreamInfo,
|
||||
)
|
||||
from markitdown._stream_info import _guess_stream_info_from_stream
|
||||
|
||||
skip_remote = (
|
||||
True if os.environ.get("GITHUB_ACTIONS") else False
|
||||
@@ -265,10 +264,16 @@ def test_stream_info_guesses() -> None:
|
||||
(os.path.join(TEST_FILES_DIR, "test.xls"), "application/vnd.ms-excel"),
|
||||
]
|
||||
|
||||
markitdown = MarkItDown()
|
||||
for file_path, expected_mimetype in test_tuples:
|
||||
with open(file_path, "rb") as f:
|
||||
guesses = _guess_stream_info_from_stream(
|
||||
f, filename_hint=os.path.basename(file_path)
|
||||
guesses = markitdown._get_stream_info_guesses(
|
||||
f,
|
||||
StreamInfo(
|
||||
filename=os.path.basename(file_path),
|
||||
local_path=file_path,
|
||||
extension=os.path.splitext(file_path)[1],
|
||||
),
|
||||
)
|
||||
assert len(guesses) > 0
|
||||
assert guesses[0].mimetype == expected_mimetype
|
||||
|
||||
Reference in New Issue
Block a user