Skip to content

Commit

Permalink
Add more strict parameter for GeolocationDbUpdater
Browse files Browse the repository at this point in the history
  • Loading branch information
acelaya committed Dec 11, 2024
1 parent 06c0a94 commit b8ac9f3
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 63 deletions.
36 changes: 22 additions & 14 deletions module/CLI/src/Command/Visit/DownloadGeoLiteDbCommand.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use Shlinkio\Shlink\CLI\Util\ExitCode;
use Shlinkio\Shlink\Core\Exception\GeolocationDbUpdateFailedException;
use Shlinkio\Shlink\Core\Geolocation\GeolocationDbUpdaterInterface;
use Shlinkio\Shlink\Core\Geolocation\GeolocationDownloadProgressHandlerInterface;
use Shlinkio\Shlink\Core\Geolocation\GeolocationResult;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Helper\ProgressBar;
Expand All @@ -16,13 +17,14 @@

use function sprintf;

class DownloadGeoLiteDbCommand extends Command
class DownloadGeoLiteDbCommand extends Command implements GeolocationDownloadProgressHandlerInterface
{
public const string NAME = 'visit:download-db';

private ProgressBar|null $progressBar = null;
private SymfonyStyle $io;

public function __construct(private GeolocationDbUpdaterInterface $dbUpdater)
public function __construct(private readonly GeolocationDbUpdaterInterface $dbUpdater)
{
parent::__construct();
}
Expand All @@ -39,32 +41,26 @@ protected function configure(): void

protected function execute(InputInterface $input, OutputInterface $output): int
{
$io = new SymfonyStyle($input, $output);
$this->io = new SymfonyStyle($input, $output);

try {
$result = $this->dbUpdater->checkDbUpdate(function (bool $olderDbExists) use ($io): void {
$io->text(sprintf('<fg=blue>%s GeoLite2 db file...</>', $olderDbExists ? 'Updating' : 'Downloading'));
$this->progressBar = new ProgressBar($io);
}, function (int $total, int $downloaded): void {
$this->progressBar?->setMaxSteps($total);
$this->progressBar?->setProgress($downloaded);
});
$result = $this->dbUpdater->checkDbUpdate($this);

if ($result === GeolocationResult::LICENSE_MISSING) {
$io->warning('It was not possible to download GeoLite2 db, because a license was not provided.');
$this->io->warning('It was not possible to download GeoLite2 db, because a license was not provided.');
return ExitCode::EXIT_WARNING;
}

if ($this->progressBar === null) {
$io->info('GeoLite2 db file is up to date.');
$this->io->info('GeoLite2 db file is up to date.');
} else {
$this->progressBar->finish();
$io->success('GeoLite2 db file properly downloaded.');
$this->io->success('GeoLite2 db file properly downloaded.');
}

return ExitCode::EXIT_SUCCESS;
} catch (GeolocationDbUpdateFailedException $e) {
return $this->processGeoLiteUpdateError($e, $io);
return $this->processGeoLiteUpdateError($e, $this->io);
}
}

Expand All @@ -86,4 +82,16 @@ private function processGeoLiteUpdateError(GeolocationDbUpdateFailedException $e

return $olderDbExists ? ExitCode::EXIT_WARNING : ExitCode::EXIT_FAILURE;
}

public function beforeDownload(bool $olderDbExists): void
{
$this->io->text(sprintf('<fg=blue>%s GeoLite2 db file...</>', $olderDbExists ? 'Updating' : 'Downloading'));
$this->progressBar = new ProgressBar($this->io);
}

public function handleProgress(int $total, int $downloaded, bool $olderDbExists): void
{
$this->progressBar?->setMaxSteps($total);
$this->progressBar?->setProgress($downloaded);
}
}
34 changes: 24 additions & 10 deletions module/CLI/test/GeoLite/GeolocationDbUpdaterTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
use Shlinkio\Shlink\Core\Config\Options\TrackingOptions;
use Shlinkio\Shlink\Core\Exception\GeolocationDbUpdateFailedException;
use Shlinkio\Shlink\Core\Geolocation\GeolocationDbUpdater;
use Shlinkio\Shlink\Core\Geolocation\GeolocationDownloadProgressHandlerInterface;
use Shlinkio\Shlink\Core\Geolocation\GeolocationResult;
use Shlinkio\Shlink\IpGeolocation\Exception\DbUpdateException;
use Shlinkio\Shlink\IpGeolocation\Exception\MissingLicenseException;
Expand All @@ -29,13 +30,32 @@ class GeolocationDbUpdaterTest extends TestCase
private MockObject & DbUpdaterInterface $dbUpdater;
private MockObject & Reader $geoLiteDbReader;
private MockObject & Lock\LockInterface $lock;
/** @var GeolocationDownloadProgressHandlerInterface&object{beforeDownloadCalled: bool, handleProgressCalled: bool} */
private GeolocationDownloadProgressHandlerInterface $progressHandler;

protected function setUp(): void
{
$this->dbUpdater = $this->createMock(DbUpdaterInterface::class);
$this->geoLiteDbReader = $this->createMock(Reader::class);
$this->lock = $this->createMock(Lock\SharedLockInterface::class);
$this->lock->method('acquire')->with($this->isTrue())->willReturn(true);
$this->progressHandler = new class implements GeolocationDownloadProgressHandlerInterface {
public function __construct(
public bool $beforeDownloadCalled = false,
public bool $handleProgressCalled = false,
) {
}

public function beforeDownload(bool $olderDbExists): void
{
$this->beforeDownloadCalled = true;
}

public function handleProgress(int $total, int $downloaded, bool $olderDbExists): void
{
$this->handleProgressCalled = true;
}
};
}

#[Test]
Expand All @@ -47,12 +67,9 @@ public function properResultIsReturnedWhenLicenseIsMissing(): void
);
$this->geoLiteDbReader->expects($this->never())->method('metadata');

$isCalled = false;
$result = $this->geolocationDbUpdater()->checkDbUpdate(function () use (&$isCalled): void {
$isCalled = true;
});
$result = $this->geolocationDbUpdater()->checkDbUpdate($this->progressHandler);

self::assertTrue($isCalled);
self::assertTrue($this->progressHandler->beforeDownloadCalled);
self::assertEquals(GeolocationResult::LICENSE_MISSING, $result);
}

Expand All @@ -67,17 +84,14 @@ public function exceptionIsThrownWhenOlderDbDoesNotExistAndDownloadFails(): void
)->willThrowException($prev);
$this->geoLiteDbReader->expects($this->never())->method('metadata');

$isCalled = false;
try {
$this->geolocationDbUpdater()->checkDbUpdate(function () use (&$isCalled): void {
$isCalled = true;
});
$this->geolocationDbUpdater()->checkDbUpdate($this->progressHandler);
self::fail();
} catch (Throwable $e) {
self::assertInstanceOf(GeolocationDbUpdateFailedException::class, $e);
self::assertSame($prev, $e->getPrevious());
self::assertFalse($e->olderDbExists());
self::assertTrue($isCalled);
self::assertTrue($this->progressHandler->beforeDownloadCalled);
}
}

Expand Down
41 changes: 28 additions & 13 deletions module/Core/src/EventDispatcher/UpdateGeoLiteDb.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use Psr\Log\LoggerInterface;
use Shlinkio\Shlink\Core\EventDispatcher\Event\GeoLiteDbCreated;
use Shlinkio\Shlink\Core\Geolocation\GeolocationDbUpdaterInterface;
use Shlinkio\Shlink\Core\Geolocation\GeolocationDownloadProgressHandlerInterface;
use Shlinkio\Shlink\Core\Geolocation\GeolocationResult;
use Throwable;

Expand All @@ -24,21 +25,35 @@ public function __construct(

public function __invoke(): void
{
$beforeDownload = fn (bool $olderDbExists) => $this->logger->notice(
sprintf('%s GeoLite2 db file...', $olderDbExists ? 'Updating' : 'Downloading'),
);
$messageLogged = false;
$handleProgress = function (int $total, int $downloaded, bool $olderDbExists) use (&$messageLogged): void {
if ($messageLogged || $total > $downloaded) {
return;
}
try {
$result = $this->dbUpdater->checkDbUpdate(
new class ($this->logger) implements GeolocationDownloadProgressHandlerInterface {
public function __construct(
private readonly LoggerInterface $logger,
private bool $messageLogged = false,
) {
}

$messageLogged = true;
$this->logger->notice(sprintf('Finished %s GeoLite2 db file', $olderDbExists ? 'updating' : 'downloading'));
};
public function beforeDownload(bool $olderDbExists): void
{
$this->logger->notice(
sprintf('%s GeoLite2 db file...', $olderDbExists ? 'Updating' : 'Downloading'),
);
}

try {
$result = $this->dbUpdater->checkDbUpdate($beforeDownload, $handleProgress);
public function handleProgress(int $total, int $downloaded, bool $olderDbExists): void
{
if ($this->messageLogged || $total > $downloaded) {
return;
}

$this->messageLogged = true;
$this->logger->notice(
sprintf('Finished %s GeoLite2 db file', $olderDbExists ? 'updating' : 'downloading'),
);
}
},
);
if ($result === GeolocationResult::DB_CREATED) {
$this->eventDispatcher->dispatch(new GeoLiteDbCreated());
}
Expand Down
38 changes: 14 additions & 24 deletions module/Core/src/Geolocation/GeolocationDbUpdater.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
use Shlinkio\Shlink\Core\Exception\GeolocationDbUpdateFailedException;
use Shlinkio\Shlink\IpGeolocation\Exception\DbUpdateException;
use Shlinkio\Shlink\IpGeolocation\Exception\MissingLicenseException;
use Shlinkio\Shlink\IpGeolocation\Exception\WrongIpException;
use Shlinkio\Shlink\IpGeolocation\GeoLite2\DbUpdaterInterface;
use Symfony\Component\Lock\LockFactory;

Expand Down Expand Up @@ -41,8 +40,7 @@ public function __construct(
* @throws GeolocationDbUpdateFailedException
*/
public function checkDbUpdate(
callable|null $beforeDownload = null,
callable|null $handleProgress = null,
GeolocationDownloadProgressHandlerInterface|null $downloadProgressHandler = null,
): GeolocationResult {
if (! $this->trackingOptions->isGeolocationRelevant()) {
return GeolocationResult::CHECK_SKIPPED;
Expand All @@ -52,7 +50,7 @@ public function checkDbUpdate(
$lock->acquire(true); // Block until lock is released

try {
return $this->downloadIfNeeded($beforeDownload, $handleProgress);
return $this->downloadIfNeeded($downloadProgressHandler);
} finally {
$lock->release();
}
Expand All @@ -61,15 +59,16 @@ public function checkDbUpdate(
/**
* @throws GeolocationDbUpdateFailedException
*/
private function downloadIfNeeded(callable|null $beforeDownload, callable|null $handleProgress): GeolocationResult
{
private function downloadIfNeeded(
GeolocationDownloadProgressHandlerInterface|null $downloadProgressHandler,
): GeolocationResult {
if (! $this->dbUpdater->databaseFileExists()) {
return $this->downloadNewDb($beforeDownload, $handleProgress, olderDbExists: false);
return $this->downloadNewDb($downloadProgressHandler, olderDbExists: false);
}

$meta = ($this->geoLiteDbReaderFactory)()->metadata();
if ($this->buildIsTooOld($meta)) {
return $this->downloadNewDb($beforeDownload, $handleProgress, olderDbExists: true);
return $this->downloadNewDb($downloadProgressHandler, olderDbExists: true);
}

return GeolocationResult::DB_IS_UP_TO_DATE;
Expand Down Expand Up @@ -106,32 +105,23 @@ private function resolveBuildTimestamp(Metadata $meta): int
* @throws GeolocationDbUpdateFailedException
*/
private function downloadNewDb(
callable|null $beforeDownload,
callable|null $handleProgress,
GeolocationDownloadProgressHandlerInterface|null $downloadProgressHandler,
bool $olderDbExists,
): GeolocationResult {
if ($beforeDownload !== null) {
$beforeDownload($olderDbExists);
}
$downloadProgressHandler?->beforeDownload($olderDbExists);

try {
$this->dbUpdater->downloadFreshCopy($this->wrapHandleProgressCallback($handleProgress, $olderDbExists));
$this->dbUpdater->downloadFreshCopy(
static fn (int $total, int $downloaded)
=> $downloadProgressHandler?->handleProgress($total, $downloaded, $olderDbExists),
);
return $olderDbExists ? GeolocationResult::DB_UPDATED : GeolocationResult::DB_CREATED;
} catch (MissingLicenseException) {
return GeolocationResult::LICENSE_MISSING;
} catch (DbUpdateException | WrongIpException $e) {
} catch (DbUpdateException $e) {
throw $olderDbExists
? GeolocationDbUpdateFailedException::withOlderDb($e)
: GeolocationDbUpdateFailedException::withoutOlderDb($e);
}
}

private function wrapHandleProgressCallback(callable|null $handleProgress, bool $olderDbExists): callable|null
{
if ($handleProgress === null) {
return null;
}

return static fn (int $total, int $downloaded) => $handleProgress($total, $downloaded, $olderDbExists);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ interface GeolocationDbUpdaterInterface
* @throws GeolocationDbUpdateFailedException
*/
public function checkDbUpdate(
callable|null $beforeDownload = null,
callable|null $handleProgress = null,
GeolocationDownloadProgressHandlerInterface|null $downloadProgressHandler = null,
): GeolocationResult;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
<?php

declare(strict_types=1);

namespace Shlinkio\Shlink\Core\Geolocation;

interface GeolocationDownloadProgressHandlerInterface
{
/**
* Invoked right before starting to download a geolocation DB file, and only if it needs to be downloaded.
* @param $olderDbExists - Indicates if an older DB file already exists when this method is called
*/
public function beforeDownload(bool $olderDbExists): void;

/**
* Invoked every time a new chunk of the new DB file is downloaded, with the total size of the file and how much has
* already been downloaded.
* @param $olderDbExists - Indicates if an older DB file already exists when this method is called
*/
public function handleProgress(int $total, int $downloaded, bool $olderDbExists): void;
}

0 comments on commit b8ac9f3

Please sign in to comment.