Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions ankihub/ankihub_client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set

from anki.models import NotetypeId
from mashumaro import field_options
from mashumaro.config import BaseConfig
from mashumaro.mixins.json import DataClassJSONMixin
Expand Down Expand Up @@ -317,9 +318,12 @@ class DailyCardReviewSummary(DataClassJSONMixinWithConfig):


def get_media_names_from_notes_data(notes_data: Sequence[NoteInfo]) -> Set[str]:
"""Return the names of all media files on the given notes.
"""Return the names of all media files on the given notes and their note types.
Only returns names of local files, not remote files."""
return {name for note in notes_data for name in get_media_names_from_note_info(note)}
note_refs = {name for note in notes_data for name in get_media_names_from_note_info(note)}
note_type_ids = {note.mid for note in notes_data}
note_type_refs = {name for note_type_id in note_type_ids for name in get_media_names_from_notetype(note_type_id)}
return {*note_refs, *note_type_refs}


def get_media_names_from_suggestions(suggestions: Sequence[NoteSuggestion]) -> Set[str]:
Expand All @@ -338,6 +342,18 @@ def get_media_names_from_note_info(note_info: NoteInfo) -> Set[str]:
return result


def get_media_names_from_notetype(notetype_id: int) -> Set[str]:
import aqt

refs = set()
note_type = aqt.mw.col.models.get(NotetypeId(notetype_id))
refs.update(local_media_names_from_html(note_type["css"]))
for tmpl in note_type["tmpls"]:
refs.update(local_media_names_from_html(tmpl["qfmt"]))
refs.update(local_media_names_from_html(tmpl["afmt"]))
return refs


def _get_media_names_from_field(field: Field) -> Set[str]:
"""Return the names of all media files on the given field. Only returns names of local files, not remote files."""
result = local_media_names_from_html(field.value)
Expand Down
12 changes: 9 additions & 3 deletions ankihub/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@

# Regex to find the name of image files inside an <img> tag in HTML
# excluding the ones that start with http:// or https://
IMG_NAME_IN_IMG_TAG_REGEX = re.compile(r"<img.*?src=[\"'](?!http://|https://)(.+?)[\"']")
IMG_NAME_IN_IMG_TAG_REGEX = re.compile(r"(?i)<img.*?src=[\"'](?!http://|https://)(.+?)[\"']")
# Regex to find the name of sound files inside a [sound] tag (specific to Anki)
# excluding the ones that start with http:// or https://
SOUND_NAME_IN_SOUND_TAG_REGEX = re.compile(r"\[sound:(?!http://|https://)(.+?)\]")
SOUND_NAME_IN_SOUND_TAG_REGEX = re.compile(r"(?i)\[sound:(?!http://|https://)(.+?)\]")
# Regex to find CSS import statements and url() references
CSS_IMPORT_REGEX = re.compile(r"(?i)(?:@import\s+[\"'](.+?)[\"'])")
# Regex to find CSS url() references
CSS_URL_REGEX = re.compile(r"(?i)(?:url\(\s*[\"']([^\"]+)[\"'])")


def local_media_names_from_html(html_content: str) -> Set[str]:
image_names = re.findall(IMG_NAME_IN_IMG_TAG_REGEX, html_content)
sound_names = re.findall(SOUND_NAME_IN_SOUND_TAG_REGEX, html_content)
all_names = set(image_names + sound_names)
css_import_names = re.findall(CSS_IMPORT_REGEX, html_content)
css_url_names = re.findall(CSS_URL_REGEX, html_content)
all_names = set(image_names + sound_names + css_import_names + css_url_names)
return all_names
7 changes: 5 additions & 2 deletions ankihub/gui/media_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .. import LOGGER
from ..addon_ankihub_client import AddonAnkiHubClient
from ..ankihub_client.models import DeckMedia
from ..ankihub_client.models import DeckMedia, get_media_names_from_notetype
from ..db import ankihub_db
from ..settings import config, get_anki_profile_id
from .operations import AddonQueryOp
Expand Down Expand Up @@ -187,15 +187,18 @@ def _media_referenced_by_notes(self, ah_did: uuid.UUID) -> Set[str]:
anki_nids: List[NoteId] = ankihub_db.anki_nids_for_ankihub_deck(ah_did)

media_names: Set[str] = set()
note_type_ids: Set[int] = set()
for nid in anki_nids:
try:
note = aqt.mw.col.get_note(nid)
except NotFoundError:
continue
note_type_ids.add(note.mid)
flds = "".join(note.fields)
# Extract media references using Anki's files_in_str (handles latex)
media_names.update(aqt.mw.col.media.files_in_str(note.mid, flds))

for note_type_id in note_type_ids:
media_names.update(get_media_names_from_notetype(note_type_id))
return media_names

def _missing_media_for_ah_deck(self, ah_did: uuid.UUID) -> List[str]:
Expand Down
129 changes: 129 additions & 0 deletions tests/addon/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,135 @@ def test_with_deck_name_existing(
create_ankihub_deck_mock.assert_not_called()


class TestMediaNameExtraction:
"""Tests for media name extraction from notes and note types."""

def test_extracts_media_from_note_fields(
self,
anki_session_with_addon_data: AnkiSession,
):
"""Media references in note fields (images and sounds) are extracted."""
from ankihub.ankihub_client.models import get_media_names_from_notes_data

with anki_session_with_addon_data.profile_loaded():
mw = anki_session_with_addon_data.mw
note_type = mw.col.models.by_name("Basic")
mid = note_type["id"]

notes_data = [
NoteInfoFactory.create(
mid=mid,
fields=[
Field(name="Front", value='<img src="image1.png">'),
Field(name="Back", value="[sound:audio1.mp3]"),
],
),
NoteInfoFactory.create(
mid=mid,
fields=[
Field(name="Front", value='<img src="image2.jpg"> and [sound:audio2.wav]'),
Field(name="Back", value="no media here"),
],
),
]

media_names = get_media_names_from_notes_data(notes_data)

assert media_names == {"image1.png", "audio1.mp3", "image2.jpg", "audio2.wav"}

def test_extracts_media_from_note_type_templates(
self,
anki_session_with_addon_data: AnkiSession,
):
"""Media references in note type templates (qfmt/afmt) are extracted."""
from ankihub.ankihub_client.models import get_media_names_from_notes_data

with anki_session_with_addon_data.profile_loaded():
mw = anki_session_with_addon_data.mw

note_type = note_type_with_field_names(["Front", "Back"])
note_type["tmpls"][0]["qfmt"] = '{{Front}}<img src="template_image.png">'
note_type["tmpls"][0]["afmt"] = "{{Back}}[sound:template_audio.mp3]"
mw.col.models.add_dict(note_type)
note_type = mw.col.models.by_name(note_type["name"])
mid = note_type["id"]

notes_data = [
NoteInfoFactory.create(
mid=mid,
fields=[
Field(name="Front", value="plain text"),
Field(name="Back", value="plain text"),
],
),
]

media_names = get_media_names_from_notes_data(notes_data)

assert "template_image.png" in media_names
assert "template_audio.mp3" in media_names

def test_extracts_media_from_note_type_css(
self,
anki_session_with_addon_data: AnkiSession,
):
"""Media references in note type CSS are extracted."""
from ankihub.ankihub_client.models import get_media_names_from_notes_data

with anki_session_with_addon_data.profile_loaded():
mw = anki_session_with_addon_data.mw

note_type = note_type_with_field_names(["Front", "Back"])
note_type["css"] = (
'@import "foo_import.css"; '
'.card { background: url("foo_double_quoted.png");'
" background: url('foo_single_quoted.png'); }"
)
mw.col.models.add_dict(note_type)
note_type = mw.col.models.by_name(note_type["name"])
mid = note_type["id"]

notes_data = [
NoteInfoFactory.create(
mid=mid,
fields=[
Field(name="Front", value="plain text"),
Field(name="Back", value="plain text"),
],
),
]

media_names = get_media_names_from_notes_data(notes_data)

assert media_names == {"foo_import.css", "foo_double_quoted.png", "foo_single_quoted.png"}

def test_excludes_remote_media_urls(
self,
anki_session_with_addon_data: AnkiSession,
):
"""Remote URLs (http/https) are excluded from media extraction."""
from ankihub.ankihub_client.models import get_media_names_from_notes_data

with anki_session_with_addon_data.profile_loaded():
mw = anki_session_with_addon_data.mw
note_type = mw.col.models.by_name("Basic")
mid = note_type["id"]

notes_data = [
NoteInfoFactory.create(
mid=mid,
fields=[
Field(name="Front", value='<img src="local.png"><img src="http://example.com/remote.png">'),
Field(name="Back", value='<img src="https://example.com/secure.jpg">[sound:local.mp3]'),
],
),
]

media_names = get_media_names_from_notes_data(notes_data)

assert media_names == {"local.png", "local.mp3"}


class TestGetReviewCountForAHDeckSince:
@pytest.mark.parametrize(
"review_deltas, since_time, expected_count",
Expand Down
11 changes: 7 additions & 4 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,7 +1601,7 @@ def test_zips_media_files_from_deck_notes(

deck_id = next_deterministic_uuid()
remove_mock = mocker.patch("os.remove")
self._upload_media_for_notes_data(client, notes_data, deck_id)
self._upload_media_for_notes_data(mocker, client, notes_data, deck_id)

# We will create and check for just one chunk in this test
path_to_created_zip_file = Path(TEST_MEDIA_PATH / f"{deck_id}_0_deck_assets_part.zip")
Expand Down Expand Up @@ -1647,7 +1647,7 @@ def test_uploads_generated_zipped_file(
"_upload_file_to_s3_with_reusable_presigned_url",
)

self._upload_media_for_notes_data(client, notes_data, deck_id)
self._upload_media_for_notes_data(mocker, client, notes_data, deck_id)

get_presigned_url_mock.assert_called_once_with(prefix=f"deck_assets/{deck_id}")
mocked_upload_file_to_s3.assert_called_once_with(
Expand All @@ -1667,13 +1667,16 @@ def test_removes_zipped_file_after_upload(
mocker.patch.object(client, "_upload_file_to_s3_with_reusable_presigned_url")

deck_id = next_deterministic_uuid()
self._upload_media_for_notes_data(client, notes_data, deck_id)
self._upload_media_for_notes_data(mocker, client, notes_data, deck_id)

path_to_created_zip_file = Path(TEST_MEDIA_PATH / f"{deck_id}.zip")

assert not path_to_created_zip_file.is_file()

def _upload_media_for_notes_data(self, client: AnkiHubClient, notes_data: List[NoteInfo], ah_did: uuid.UUID):
def _upload_media_for_notes_data(
self, mocker: MockerFixture, client: AnkiHubClient, notes_data: List[NoteInfo], ah_did: uuid.UUID
):
mocker.patch("ankihub.ankihub_client.models.get_media_names_from_notetype", return_value=set())
media_names = get_media_names_from_notes_data(notes_data)
media_paths = {TEST_MEDIA_PATH / media_name for media_name in media_names}
client.upload_media(media_paths, ah_did)
Expand Down
Loading