diff --git a/ankihub/ankihub_client/models.py b/ankihub/ankihub_client/models.py index 38678df38..6617367be 100644 --- a/ankihub/ankihub_client/models.py +++ b/ankihub/ankihub_client/models.py @@ -7,9 +7,8 @@ from dataclasses import dataclass from datetime import date, datetime from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Set +from typing import Any, Callable, Dict, List, Optional, Sequence, Set -from anki.models import NotetypeId from mashumaro import field_options from mashumaro.config import BaseConfig from mashumaro.mixins.json import DataClassJSONMixin @@ -317,12 +316,19 @@ class DailyCardReviewSummary(DataClassJSONMixinWithConfig): # Media related functions -def get_media_names_from_notes_data(notes_data: Sequence[NoteInfo]) -> Set[str]: +def get_media_names_from_notes_data( + notes_data: Sequence[NoteInfo], + get_notetype_by_id: Callable[[int], Dict[str, Any]], +) -> Set[str]: """Return the names of all media files on the given notes and their note types. Only returns names of local files, not remote files.""" note_refs = {name for note in notes_data for name in get_media_names_from_note_info(note)} note_type_ids = {note.mid for note in notes_data} - note_type_refs = {name for note_type_id in note_type_ids for name in get_media_names_from_notetype(note_type_id)} + note_type_refs = { + name + for note_type_id in note_type_ids + for name in get_media_names_from_notetype(get_notetype_by_id(note_type_id)) + } return {*note_refs, *note_type_refs} @@ -342,11 +348,8 @@ def get_media_names_from_note_info(note_info: NoteInfo) -> Set[str]: return result -def get_media_names_from_notetype(notetype_id: int) -> Set[str]: - import aqt - - refs = set() - note_type = aqt.mw.col.models.get(NotetypeId(notetype_id)) +def get_media_names_from_notetype(note_type: Dict[str, Any]) -> Set[str]: + refs: Set[str] = set() refs.update(local_media_names_from_html(note_type["css"])) for tmpl in note_type["tmpls"]: refs.update(local_media_names_from_html(tmpl["qfmt"])) diff --git a/ankihub/gui/decks_dialog.py b/ankihub/gui/decks_dialog.py index b235fa24c..e7bb2cc38 100644 --- a/ankihub/gui/decks_dialog.py +++ b/ankihub/gui/decks_dialog.py @@ -651,7 +651,7 @@ def on_note_type_selected( note_type = aqt.mw.col.models.by_name(note_type_selector.name) new_note_type = add_note_type(ah_did, note_type) - media_names = get_media_names_from_notetype(new_note_type["id"]) + media_names = get_media_names_from_notetype(new_note_type) if media_names: media_sync.start_media_upload(media_names, ah_did) @@ -750,7 +750,7 @@ def on_note_type_selected(note_type_selector: SearchableSelectionDialog, MODEL_N ah_did = self._selected_ah_did() update_note_type_templates_and_styles(ah_did, note_type) - media_names = get_media_names_from_notetype(note_type["id"]) + media_names = get_media_names_from_notetype(note_type) if media_names: media_sync.start_media_upload(media_names, ah_did) diff --git a/ankihub/gui/media_sync.py b/ankihub/gui/media_sync.py index f9b473281..f3f42eb4c 100644 --- a/ankihub/gui/media_sync.py +++ b/ankihub/gui/media_sync.py @@ -8,6 +8,7 @@ import aqt from anki.errors import NotFoundError +from anki.models import NotetypeId from anki.notes import NoteId from aqt.qt import QAction @@ -198,7 +199,8 @@ def _media_referenced_by_notes(self, ah_did: uuid.UUID) -> Set[str]: # Extract media references using Anki's files_in_str (handles latex) media_names.update(aqt.mw.col.media.files_in_str(note.mid, flds)) for note_type_id in note_type_ids: - media_names.update(get_media_names_from_notetype(note_type_id)) + note_type = aqt.mw.col.models.get(NotetypeId(note_type_id)) + media_names.update(get_media_names_from_notetype(note_type)) return media_names def _missing_media_for_ah_deck(self, ah_did: uuid.UUID) -> List[str]: diff --git a/ankihub/gui/operations/deck_creation.py b/ankihub/gui/operations/deck_creation.py index cd453ed3e..b58aecd77 100644 --- a/ankihub/gui/operations/deck_creation.py +++ b/ankihub/gui/operations/deck_creation.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone import aqt +from anki.models import NotetypeId from aqt import QCheckBox, QMessageBox from aqt.studydeck import StudyDeck from aqt.utils import showInfo, tooltip @@ -116,7 +117,10 @@ def on_success(deck_creation_result: DeckCreationResult) -> None: # Upload all existing local media for this deck # (media files that are referenced on Deck's notes) if should_upload_media: - media_names = get_media_names_from_notes_data(deck_creation_result.notes_data) + media_names = get_media_names_from_notes_data( + deck_creation_result.notes_data, + lambda mid: aqt.mw.col.models.get(NotetypeId(mid)), + ) media_sync.start_media_upload(media_names, deck_creation_result.ankihub_did) # Add the deck to the list of decks the user owns diff --git a/ankihub/main/suggestions.py b/ankihub/main/suggestions.py index 5f2ad9df5..f8d6fff9a 100644 --- a/ankihub/main/suggestions.py +++ b/ankihub/main/suggestions.py @@ -23,6 +23,7 @@ ) import aqt +from anki.models import NotetypeId from anki.notes import Note, NoteId from ..addon_ankihub_client import AddonAnkiHubClient as AnkiHubClient @@ -410,7 +411,9 @@ def _rename_and_upload_media_for_suggestions( original_notes_data = [ note_info for suggestion in suggestions if (note_info := ankihub_db.note_data(NoteId(suggestion.anki_nid))) ] - original_media_names: Set[str] = get_media_names_from_notes_data(original_notes_data) + original_media_names: Set[str] = get_media_names_from_notes_data( + original_notes_data, lambda mid: aqt.mw.col.models.get(NotetypeId(mid)) + ) suggestion_media_names: Set[str] = get_media_names_from_suggestions(suggestions) # Filter out unchanged media file names so we don't hash and upload media files that aren't part of the suggestion diff --git a/tests/addon/test_unit.py b/tests/addon/test_unit.py index e0753303b..5c7411fdc 100644 --- a/tests/addon/test_unit.py +++ b/tests/addon/test_unit.py @@ -16,7 +16,7 @@ import pytest import requests from anki.decks import DeckId -from anki.models import NotetypeDict +from anki.models import NotetypeDict, NotetypeId from anki.notes import Note, NoteId from approvaltests.approvals import verify # type: ignore from approvaltests.namer import NamerFactory # type: ignore @@ -2065,7 +2065,8 @@ def test_basic( # Assert that the correct functions were called. create_ankihub_deck_mock.assert_called_once_with(deck_name, private=False, add_subdeck_tags=False) - get_media_names_from_notes_data_mock.assert_called_once_with(notes_data) + get_media_names_from_notes_data_mock.assert_called_once() + assert get_media_names_from_notes_data_mock.call_args[0][0] == notes_data start_media_upload_mock.assert_called_once() def test_with_deck_name_existing( @@ -2135,7 +2136,7 @@ def test_extracts_media_from_note_fields( ), ] - media_names = get_media_names_from_notes_data(notes_data) + media_names = get_media_names_from_notes_data(notes_data, lambda mid: mw.col.models.get(NotetypeId(mid))) assert media_names == {"image1.png", "audio1.mp3", "image2.jpg", "audio2.wav"} @@ -2166,7 +2167,7 @@ def test_extracts_media_from_note_type_templates( ), ] - media_names = get_media_names_from_notes_data(notes_data) + media_names = get_media_names_from_notes_data(notes_data, lambda mid: mw.col.models.get(NotetypeId(mid))) assert "template_image.png" in media_names assert "template_audio.mp3" in media_names @@ -2201,7 +2202,7 @@ def test_extracts_media_from_note_type_css( ), ] - media_names = get_media_names_from_notes_data(notes_data) + media_names = get_media_names_from_notes_data(notes_data, lambda mid: mw.col.models.get(NotetypeId(mid))) assert media_names == {"foo_import.css", "foo_double_quoted.png", "foo_single_quoted.png"} @@ -2227,7 +2228,7 @@ def test_excludes_remote_media_urls( ), ] - media_names = get_media_names_from_notes_data(notes_data) + media_names = get_media_names_from_notes_data(notes_data, lambda mid: mw.col.models.get(NotetypeId(mid))) assert media_names == {"local.png", "local.mp3"} diff --git a/tests/client/test_client.py b/tests/client/test_client.py index b8c95cc7d..be1f103b8 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -11,7 +11,7 @@ from copy import deepcopy from datetime import date, datetime, timedelta, timezone from pathlib import Path -from typing import Callable, Generator, List, Optional, cast +from typing import Any, Callable, Dict, Generator, List, Optional, cast from unittest.mock import Mock import pytest @@ -1606,7 +1606,7 @@ def test_zips_media_files_from_deck_notes( # We will create and check for just one chunk in this test path_to_created_zip_file = Path(TEST_MEDIA_PATH / f"{deck_id}_0_deck_assets_part.zip") - all_media_names_in_notes = get_media_names_from_notes_data(notes_data) + all_media_names_in_notes = get_media_names_from_notes_data(notes_data, lambda mid: self._empty_notetype()) assert path_to_created_zip_file.is_file() assert len(all_media_names_in_notes) == 14 with zipfile.ZipFile(path_to_created_zip_file, "r") as zip_ref: @@ -1673,11 +1673,14 @@ def test_removes_zipped_file_after_upload( assert not path_to_created_zip_file.is_file() + @staticmethod + def _empty_notetype() -> Dict[str, Any]: + return {"css": "", "tmpls": []} + def _upload_media_for_notes_data( self, mocker: MockerFixture, client: AnkiHubClient, notes_data: List[NoteInfo], ah_did: uuid.UUID ): - mocker.patch("ankihub.ankihub_client.models.get_media_names_from_notetype", return_value=set()) - media_names = get_media_names_from_notes_data(notes_data) + media_names = get_media_names_from_notes_data(notes_data, lambda mid: self._empty_notetype()) media_paths = {TEST_MEDIA_PATH / media_name for media_name in media_names} client.upload_media(media_paths, ah_did)