Skip to content

Commit

Permalink
#84: Fix #84 - an exception will no longer be thrown when saving tons…
Browse files Browse the repository at this point in the history
… of tags at once
  • Loading branch information
kamil-sita committed Jul 27, 2024
1 parent d8fe07a commit 362b251
Show file tree
Hide file tree
Showing 23 changed files with 377 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ private static TaskResult<Void> runTaskActual(ImportOldLabelleDataInput paramete
var result = inRepositoryService.addImage(parameter.repositoryId(), Paths.get(path).toFile());
if (result.isSuccess()) {
UUID imageId = result.getSuccess().getId();
PersistableImagesTags persistableImagesTags = new PersistableImagesTags(parameter.repositoryId());
PersistableImagesTags persistableImagesTags = new PersistableImagesTags();
imageCategory.imageCategoriesValues().forEach(imageCategoriesValue -> {
String category = categoryUuidCache.get(imageCategoriesValue.categoryUuid());
String baseTag = tagUuidCache.get(imageCategoriesValue.categoryValueUuid());
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/java/place/sita/labelle/core/CoreAppConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package place.sita.labelle.core;

import org.springframework.boot.context.properties.ConfigurationPropertiesScan;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
Expand All @@ -9,6 +10,7 @@

@Configuration
@ComponentScan
@ConfigurationPropertiesScan
@Import({ExtensionsConfig.class, MagicSchedulerConfig.class, CoreCommonAppConfig.class, CategoryBuilderConfig.class})
public class CoreAppConfig {
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ private static void copyImages(CloneRepositoryTaskInput parameter, TaskContext<R
imageIterator.forEachRemaining(image -> {
UUID imageId = copyImage(taskContext, newRepoId, image);

addTags(taskContext, newRepoId, image, imageId);
addTags(taskContext, image, imageId);
});
}
}

private static void addTags(TaskContext<RepositoryApi> taskContext, UUID newRepoId, ImageResponse image, UUID imageId) {
PersistableImagesTags persistableImagesTags = new PersistableImagesTags(newRepoId);
private static void addTags(TaskContext<RepositoryApi> taskContext, ImageResponse image, UUID imageId) {
PersistableImagesTags persistableImagesTags = new PersistableImagesTags();

taskContext.getApi()
.getInRepositoryService()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public TaskResult<UUID> runTask(CreateChildRepositoryTaskInput parameter, TaskCo

Map<String, UUID> referenceToNewImageId = createImagesInNewRepo(taskContext, parentImagesReferences, newRepoId);

addTagsToNewImages(taskContext, parentImagesReferences, newRepoId, referenceToNewImageId);
addTagsToNewImages(taskContext, parentImagesReferences, referenceToNewImageId);

return TaskResult.success(newRepoId);
}
Expand All @@ -53,10 +53,10 @@ private static Map<UUID, List<Tag>> getTagsOfParents(TaskContext<RepositoryApi>
return taskContext.getApi().getAcrossRepositoryService().getTags(allReferencedImages);
}

private static void addTagsToNewImages(TaskContext<RepositoryApi> taskContext, Map<String, List<UUID>> parentImagesReferences, UUID newRepoId, Map<String, UUID> referenceToNewImageId) {
private static void addTagsToNewImages(TaskContext<RepositoryApi> taskContext, Map<String, List<UUID>> parentImagesReferences, Map<String, UUID> referenceToNewImageId) {
Map<UUID, List<Tag>> tags = getTagsOfParents(taskContext, parentImagesReferences);

PersistableImagesTags persistableImagesTags = new PersistableImagesTags(newRepoId);
PersistableImagesTags persistableImagesTags = new PersistableImagesTags();

for (var newImage : referenceToNewImageId.entrySet()) {
String newImageReference = newImage.getKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public UUID duplicateImage(UUID selectedImageId) {
.values(newId, originalImage.value2(), originalImage.value3(), newReference.toString(), originalImage.value5(), originalImage.value6())
.execute();

PersistableImagesTags pit = new PersistableImagesTags(originalImage.value3());
PersistableImagesTags pit = new PersistableImagesTags();
getTags(selectedImageId).forEach(tag -> pit.addTag(newId, tag));
addTags(pit);

Expand Down Expand Up @@ -348,8 +348,8 @@ public List<Tag> getTags(UUID imageId) {
}

@Transactional
public void addTag(UUID imageId, @Nullable UUID repositoryId, Tag tag) {
tagRepository.addTag(imageId, repositoryId, tag);
public void addTag(UUID imageId, Tag tag) {
tagRepository.addTag(imageId, tag);
}

@Transactional
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
package place.sita.labelle.core.repository.inrepository.tags;

import java.util.*;
import java.util.stream.Collectors;

/**
* A structure to use if you want to quickly add a lot of tags to images and pass this info over to {@link TagRepository}.
*/
public class PersistableImagesTags {

private final UUID repositoryId;

private final Map<UUID, Set<Tag>> tags = new LinkedHashMap<>();

public PersistableImagesTags(UUID repositoryId) {
this.repositoryId = repositoryId;
}

public PersistableImagesTags() {
this.repositoryId = null;
}

public void addTag(UUID imageId, String category, String tag) {
addTag(imageId, new Tag(category, tag));
}
Expand All @@ -28,40 +17,12 @@ public void addTag(UUID imageId, Tag tag) {
tags.computeIfAbsent(imageId, k -> new LinkedHashSet<>()).add(tag);
}

public UUID repoId() {
return repositoryId;
}

public Set<UUID> images() {
return tags.keySet();
}

public Set<String> categories() {
Set<String> families = new LinkedHashSet<>();
for (var entry : tags.entrySet()) {
for (var tagValue : entry.getValue()) {
families.add(tagValue.category());
}
}
return families;
}

public Set<Tag> tags() {
Set<Tag> tagViews = new LinkedHashSet<>();
for (var entry : tags.entrySet()) {
tagViews.addAll(entry.getValue());
}
return tagViews;
}

public Set<ImageTag> imageTags() {
return tags.entrySet().stream()
.flatMap(entry -> entry.getValue().stream().map(tagValue -> new ImageTag(entry.getKey(), tagValue.category(), tagValue.tag())))
.collect(Collectors.toSet());
}

public record ImageTag(UUID imageId, String category, String tag) {

public Set<Tag> tags(UUID image) {
return tags.getOrDefault(image, Collections.emptySet());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import javax.annotation.Nullable;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import static org.jooq.impl.DSL.row;
import static place.sita.labelle.jooq.Tables.TAG;
Expand All @@ -18,9 +20,11 @@ public class TagRepository {
// todo this TagRepository assumes that *something* will clean up tags, families after they are no longer needed. Write vacuuming process

private final DSLContext dslContext;
private final TagRepositoryProperties tagRepositoryProperties;

public TagRepository(DSLContext dslContext) {
public TagRepository(DSLContext dslContext, TagRepositoryProperties tagRepositoryProperties) {
this.dslContext = dslContext;
this.tagRepositoryProperties = tagRepositoryProperties;
}

@Transactional
Expand All @@ -30,18 +34,41 @@ public void addTags(PersistableImagesTags persistableImagesTags) {
return;
}

UUID anyImage = images.iterator().next();
Map<UUID, List<UUID>> reposImages = getReposImages(images);

// if they're not in the same repository that's like the weirdest usage of this API ever, but might be worth to check if that's the case TODO
UUID actualRepositoryId = resolveRepositoryId(anyImage, persistableImagesTags.repoId());
for (var entry : reposImages.entrySet()) {
applyChangesToRepo(entry.getKey(), entry.getValue(), persistableImagesTags);
}
}

Set<String> uniqueFamilies = persistableImagesTags.categories();
Map<String, UUID> categoryIds = getOrCreateCategoryIds(actualRepositoryId, uniqueFamilies);
private void applyChangesToRepo(UUID repoId, List<UUID> imagesIds, PersistableImagesTags persistableImagesTags) {
Set<String> uniqueCategories = new HashSet<>();
Set<Tag> uniqueTags = new HashSet<>();
for (var image : imagesIds) {
for (var change : persistableImagesTags.tags(image)) {
uniqueCategories.add(change.category());
uniqueTags.add(new Tag(change.category(), change.tag()));
}
}

Set<Tag> uniqueTags = persistableImagesTags.tags();
Map<String, UUID> categoryIds = getOrCreateCategoryIds(repoId, uniqueCategories);
Map<Tag, UUID> tagIds = getOrCreateTagIds(categoryIds, uniqueTags);

assignTagsToImages(persistableImagesTags, tagIds);
assignTagsToImages(persistableImagesTags, tagIds, imagesIds);
}

private Map<UUID, List<UUID>> getReposImages(Set<UUID> images) {
var results = dslContext.select(Tables.IMAGE.REPOSITORY_ID, Tables.IMAGE.ID)
.from(Tables.IMAGE)
.where(Tables.IMAGE.ID.in(images))
.fetch();

Map<UUID, List<UUID>> reposImages = new HashMap<>();
for (var result : results) {
reposImages.computeIfAbsent(result.value1(), k -> new ArrayList<>()).add(result.value2());
}

return reposImages;
}

private Map<String, UUID> getOrCreateCategoryIds(UUID actualRepositoryId, Set<String> uniqueCategories) {
Expand Down Expand Up @@ -129,14 +156,36 @@ record TagViewId(String value, UUID categoryId) { }
return tagIds;
}

private void assignTagsToImages(PersistableImagesTags persistableImagesTags, Map<Tag, UUID> tagIds) {
private void assignTagsToImages(PersistableImagesTags persistableImagesTags, Map<Tag, UUID> tagIds, List<UUID> imagesIdsScope) {
int i = 0;
while (i < imagesIdsScope.size()) {
List<UUID> batchOfImages = imagesIdsScope.subList(i, Math.min(i + tagRepositoryProperties.getImageBulkSize(), imagesIdsScope.size()));
assignTagsToImagesBatch(persistableImagesTags, tagIds, batchOfImages);
i += tagRepositoryProperties.getImageBulkSize();
}
}

private record ImageTag(UUID imageId, String category, String tag) {

}

private void assignTagsToImagesBatch(PersistableImagesTags persistableImagesTags, Map<Tag, UUID> tagIds, List<UUID> batchOfImages) {
Map<UUID, Set<Tag>> existingTags = new HashMap<>();

List<ImageTag> imageTags = batchOfImages.stream()
.mapMulti(new BiConsumer<UUID, Consumer<ImageTag>>() {
@Override
public void accept(UUID uuid, Consumer<ImageTag> consumer) {
persistableImagesTags.tags(uuid).forEach(tag -> consumer.accept(new ImageTag(uuid, tag.category(), tag.tag())));
}
})
.toList();

dslContext.select(Tables.IMAGE_TAGS.IMAGE_ID, Tables.IMAGE_TAGS.TAG, Tables.IMAGE_TAGS.TAG_CATEGORY)
.from(Tables.IMAGE_TAGS)
.where(
row(Tables.IMAGE_TAGS.IMAGE_ID, Tables.IMAGE_TAGS.TAG, Tables.IMAGE_TAGS.TAG_CATEGORY)
.in(persistableImagesTags.imageTags().stream().map(imageTag -> row(imageTag.imageId(), imageTag.tag(), imageTag.category())).toList())
.in(imageTags.stream().map(imageTag -> row(imageTag.imageId(), imageTag.tag(), imageTag.category())).toList())
)
.fetch()
.forEach(rr -> {
Expand All @@ -150,7 +199,7 @@ private void assignTagsToImages(PersistableImagesTags persistableImagesTags, Map
.columns(Tables.TAG_IMAGE.TAG_ID, Tables.TAG_IMAGE.IMAGE_ID);

int toPersist = 0;
for (var tag : persistableImagesTags.imageTags()) {
for (var tag : imageTags) {
UUID imageId = tag.imageId();

if (existingTags.getOrDefault(imageId, Collections.emptySet()).contains(new Tag(tag.tag(), tag.category()))) {
Expand All @@ -160,17 +209,28 @@ private void assignTagsToImages(PersistableImagesTags persistableImagesTags, Map
UUID tagId = tagIds.get(new Tag(tag.category(), tag.tag()));
ongoing = ongoing.values(tagId, imageId);
toPersist++;
if (toPersist > tagRepositoryProperties.getTagBulkSize()) {
int c = ongoing.execute();
if (c != toPersist) {
throw new RuntimeException();
}
toPersist = 0;
ongoing = dslContext.insertInto(Tables.TAG_IMAGE)
.columns(Tables.TAG_IMAGE.TAG_ID, Tables.TAG_IMAGE.IMAGE_ID);
}
}

int c = ongoing.execute();
if (c != toPersist) {
throw new RuntimeException();
if (toPersist > 0) {
int c = ongoing.execute();
if (c != toPersist) {
throw new RuntimeException();
}
}
}

@Transactional
public void addTag(UUID imageId, @Nullable UUID repositoryId, Tag tag) {
PersistableImagesTags persistableImagesTags = new PersistableImagesTags(repositoryId);
public void addTag(UUID imageId, Tag tag) {
PersistableImagesTags persistableImagesTags = new PersistableImagesTags();
persistableImagesTags.addTag(imageId, tag);
addTags(persistableImagesTags);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package place.sita.labelle.core.repository.inrepository.tags;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

@Component
@ConfigurationProperties(prefix = "place.sita.magic.scheduler")
public class TagRepositoryProperties {

private int tagBulkSize = 100;
private int imageBulkSize = 20;

public int getTagBulkSize() {
return tagBulkSize;
}

public void setTagBulkSize(int tagBulkSize) {
this.tagBulkSize = tagBulkSize;
}

public int getImageBulkSize() {
return imageBulkSize;
}

public void setImageBulkSize(int imageBulkSize) {
this.imageBulkSize = imageBulkSize;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public TaskResult<Response> runTask(Config parameter, TaskContext<RepositoryApi>
List<String> tags = getTagsInPath(file.getPath(), directory);
for (String tag : tags) {
tagsToImages.get(img.id()).add(tag);
taskContext.getApi().getInRepositoryService().addTag(img.id(), parameter.repositoryId, new Tag("folder", tag));
taskContext.getApi().getInRepositoryService().addTag(img.id(), new Tag("folder", tag));
}
} else {
taskContext.log("Failed to add image: " + file.getAbsolutePath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public TaskResult<Response> runTask(Config parameter, TaskContext<RepositoryApi>

Set<UUID> image = new HashSet<>();
InRepositoryService inRepositoryService = taskContext.getApi().getInRepositoryService();
PersistableImagesTags persistableImagesTags = new PersistableImagesTags(parameter.repositoryId);
PersistableImagesTags persistableImagesTags = new PersistableImagesTags();

try (CloseableIterator<ImageResponse> images = inRepositoryService.images().process().filterByRepository(parameter.repositoryId()).getIterator()) {
images.forEachRemaining(ir -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ private List<String> applyTags(RepositoryApi repositoryApi, Response resp, UUID
List<String> tags = new ArrayList<>();
if (tagger.equals("clip")) {
ctx.log("Got description: " + resp.caption);
repositoryApi.getInRepositoryService().addTag(imageId, null, new Tag(tagger, resp.caption));
repositoryApi.getInRepositoryService().addTag(imageId, new Tag(tagger, resp.caption));
tags.add(resp.caption);
} else {
Arrays.stream(resp.caption.split(", ")).forEach(tag -> {
ctx.log("Adding a tag: " + tag);
repositoryApi.getInRepositoryService().addTag(imageId, null, new Tag(tagger, tag));
repositoryApi.getInRepositoryService().addTag(imageId, new Tag(tagger, tag));
tags.add(tag);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ public void shouldClearRepository() {
// given
Repository repo = repositoryService.addRepository("Test repo");
UUID imageId = inRepositoryService.addEmptySyntheticImage(repo.id());
inRepositoryService.addTag(imageId, null, new Tag("Some category", "Some tag"));
inRepositoryService.addTag(imageId, new Tag("Some category", "Some tag"));
UUID anotherImageId = inRepositoryService.addEmptySyntheticImage(repo.id());
inRepositoryService.addTag(anotherImageId, null, new Tag("Some category", "Some tag"));
inRepositoryService.addTag(anotherImageId, null, new Tag("Some category 2", "Some tag"));
inRepositoryService.addTag(anotherImageId, new Tag("Some category", "Some tag"));
inRepositoryService.addTag(anotherImageId, new Tag("Some category 2", "Some tag"));

// when
taskExecutionEnvironment.executeTask(
Expand All @@ -71,10 +71,10 @@ public void shouldNotClearUnrelatedRepository() {
// given
Repository repo = repositoryService.addRepository("Test repo");
UUID imageId = inRepositoryService.addEmptySyntheticImage(repo.id());
inRepositoryService.addTag(imageId, null, new Tag("Some category", "Some tag"));
inRepositoryService.addTag(imageId, new Tag("Some category", "Some tag"));
UUID anotherImageId = inRepositoryService.addEmptySyntheticImage(repo.id());
inRepositoryService.addTag(anotherImageId, null, new Tag("Some category", "Some tag"));
inRepositoryService.addTag(anotherImageId, null, new Tag("Some category 2", "Some tag"));
inRepositoryService.addTag(anotherImageId, new Tag("Some category", "Some tag"));
inRepositoryService.addTag(anotherImageId, new Tag("Some category 2", "Some tag"));

Repository unrelatedRepo = repositoryService.addRepository("Unrelated test repo");
inRepositoryService.addEmptySyntheticImage(unrelatedRepo.id());
Expand Down
Loading

0 comments on commit 362b251

Please sign in to comment.