Skip to content

Commit

Permalink
Merge pull request #10988 from archesproject/jtw/fix-basic-auth
Browse files Browse the repository at this point in the history
Fix basic HTTP authorization via search export GET #10986
  • Loading branch information
aj-he authored Jun 5, 2024
2 parents 115603a + abe20d3 commit 32f8afa
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
10 changes: 6 additions & 4 deletions arches/app/views/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from base64 import b64decode
import importlib
import json
import logging
Expand All @@ -7,8 +6,9 @@
import sys
import uuid
import traceback
from io import StringIO
from oauth2_provider.views import ProtectedResourceView
from base64 import b64decode
from http import HTTPStatus
from pyld.jsonld import compact, frame, from_rdf
from rdflib import RDF
from rdflib.namespace import SKOS, DCTERMS
Expand Down Expand Up @@ -42,7 +42,7 @@
from arches.app.views.tile import TileData as TileView
from arches.app.views.resource import RelatedResourcesView, get_resource_relationship_types
from arches.app.utils.skos import SKOSWriter
from arches.app.utils.response import JSONResponse
from arches.app.utils.response import JSONResponse, JSONErrorResponse
from arches.app.utils.decorators import can_read_concept, group_required
from arches.app.utils.betterJSONSerializer import JSONSerializer, JSONDeserializer
from arches.app.utils.data_management.resources.exporter import ResourceExporter
Expand Down Expand Up @@ -1061,13 +1061,15 @@ def get(self, request):
download_limit = settings.SEARCH_EXPORT_IMMEDIATE_DOWNLOAD_THRESHOLD
format = request.GET.get("format", "tilecsv")
report_link = request.GET.get("reportlink", False)
if "HTTP_AUTHORIZATION" in request.META and not request.get("limited", False):
if "HTTP_AUTHORIZATION" in request.META and not request.GET.get("limited", False):
request_auth = request.META.get("HTTP_AUTHORIZATION").split()
if request_auth[0].lower() == "basic":
user_cred = b64decode(request_auth[1]).decode().split(":")
user = authenticate(username=user_cred[0], password=user_cred[1])
if user is not None:
request.user = user
else:
return JSONErrorResponse(status=HTTPStatus.UNAUTHORIZED)
exporter = SearchResultsExporter(search_request=request)
export_files, export_info = exporter.export(format, report_link)
if format == "geojson" and total <= download_limit:
Expand Down
1 change: 1 addition & 0 deletions releases/6.2.8.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Arches 6.2.8 release notes
### Bug Fixes and Enhancements

- Preserve tile sortorder during `import_business_data` (from JSON) #10874
- Fix authentication via SearchExport GET request #10986

### Dependency changes:
```
Expand Down
48 changes: 48 additions & 0 deletions tests/search/search_export_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from base64 import b64encode
from http import HTTPStatus

from django.contrib.auth.models import User
from django.test.client import RequestFactory
from django.urls import reverse

from arches.app.views.api import SearchExport
from tests.base_test import ArchesTestCase

# these tests can be run from the command line via
# python manage.py test tests.search.search_export_tests --settings="tests.test_settings"

class SearchExportTests(ArchesTestCase):
def test_login_via_basic_auth_good(self):
auth_string = "Basic " + b64encode(b"admin:admin").decode("utf-8")
request = RequestFactory().get(
reverse("api_export_results"),
HTTP_AUTHORIZATION=auth_string,
)
request.user = User.objects.get(username="anonymous")
response = SearchExport().get(request)
self.assertEqual(request.user.username, "admin")
self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)

def test_login_via_basic_auth_rate_limited(self):
auth_string = "Basic " + b64encode(b"admin:admin").decode("utf-8")
request = RequestFactory().get(
reverse("api_export_results"),
HTTP_AUTHORIZATION=auth_string,
# In reality this would be added by django_ratelimit.
QUERY_STRING="limited=True",
)
request.user = User.objects.get(username="anonymous")
response = SearchExport().get(request)
self.assertEqual(request.user.username, "anonymous")
self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)

def test_login_via_basic_auth_invalid(self):
bad_auth_string = "Basic " + b64encode(b"admin:garbage").decode("utf-8")
request = RequestFactory().get(
reverse("api_export_results"),
HTTP_AUTHORIZATION=bad_auth_string,
)
request.user = User.objects.get(username="anonymous")
response = SearchExport().get(request)
self.assertEqual(request.user.username, "anonymous")
self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED)

0 comments on commit 32f8afa

Please sign in to comment.