import uuid
import csv
import io
import time
from base64 import b64encode
from http import HTTPStatus

from arches.app.models import models
from arches.app.models.tile import Tile
from arches.app.search.elasticsearch_dsl_builder import Query
from arches.app.search.mappings import TERMS_INDEX, CONCEPTS_INDEX, RESOURCES_INDEX
from arches.app.search.search_engine_factory import SearchEngineFactory
from arches.app.search.search_export import SearchResultsExporter
from arches.app.utils.skos import SKOSReader
from django.contrib.auth.models import User
from django.test.client import RequestFactory
from django.urls import get_script_prefix, reverse, set_script_prefix
from arches.app.views.api import SearchExport
from tests.base_test import ArchesTestCase
from tests.utils.search_test_utils import sync_es

# these tests can be run from the command line via
# python manage.py test tests.search.search_export_tests --settings="tests.test_settings"


class SearchExportTests(ArchesTestCase):
    graph_fixtures = ["Search Test Model"]
    search_model_graphid = "d291a445-fa5f-11e6-afa8-14109fd34195"
    search_model_cultural_period_nodeid = "7a182580-fa60-11e6-96d1-14109fd34195"
    search_model_cultural_period_nodename = "Cultural Period Concept"
    search_model_name_nodeid = "2fe14de3-fa61-11e6-897b-14109fd34195"

    def setUp(self):
        super().setUp()
        se = SearchEngineFactory().create()
        sync_es(se)

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        se = SearchEngineFactory().create()
        q = Query(se=se)
        for indexname in [TERMS_INDEX, CONCEPTS_INDEX, RESOURCES_INDEX]:
            q.delete(index=indexname, refresh=True)

        cls.factory = RequestFactory()
        cls.user = User.objects.create_user(
            "unprivileged_user", "unprivileged_user@test.com", "test"
        )

        cls.test_resourceinstanceid = uuid.uuid4()
        models.ResourceInstance.objects.get_or_create(
            graph_id=cls.search_model_graphid,
            resourceinstanceid=cls.test_resourceinstanceid,
        )

        tile_data = {}
        tile_data[cls.search_model_name_nodeid] = {
            "en": {"value": "Etiwanda Avenue Street Trees", "direction": "ltr"}
        }
        new_tile = Tile(
            resourceinstance_id=cls.test_resourceinstanceid,
            data=tile_data,
            nodegroup_id=cls.search_model_name_nodeid,
        )
        new_tile.save()

        skos = SKOSReader()
        rdf = skos.read_file("tests/fixtures/data/concept_label_test_scheme.xml")
        ret = skos.save_concepts_from_skos(rdf)

        skos = SKOSReader()
        rdf = skos.read_file("tests/fixtures/data/concept_label_test_collection.xml")
        ret = skos.save_concepts_from_skos(rdf)

        cls.valueid = "dadaeee5-57ef-409d-a6cf-98d332fdada8"
        cultural_period_tile = Tile(
            data={cls.search_model_cultural_period_nodeid: cls.valueid},
            nodegroup_id=cls.search_model_cultural_period_nodeid,
            resourceinstance_id=cls.test_resourceinstanceid,
        )
        cultural_period_tile.save(index=False)
        cultural_period_tile.index()

        # TODO: create geospatial test data

        # Without some sleep, the next class's setup might fail when
        # get_resource_model_label() raises Node.DoesNotExist
        time.sleep(1)
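    # The TODO above still calls for geospatial fixture data, which the
    # commented-out SHP/GeoJSON export tests further down would need. A rough,
    # hypothetical sketch of what such a fixture might look like inside
    # setUpTestData is kept here for reference; the geometry nodeid below is a
    # placeholder, not a node from the actual "Search Test Model" graph, so the
    # block stays commented out.
    #
    # geometry_nodeid = "00000000-0000-0000-0000-000000000000"  # placeholder
    # geometry_tile = Tile(
    #     resourceinstance_id=cls.test_resourceinstanceid,
    #     nodegroup_id=geometry_nodeid,
    #     data={
    #         geometry_nodeid: {
    #             "type": "FeatureCollection",
    #             "features": [
    #                 {
    #                     "type": "Feature",
    #                     "geometry": {"type": "Point", "coordinates": [-118.24, 34.05]},
    #                     "properties": {},
    #                 }
    #             ],
    #         }
    #     },
    # )
    # geometry_tile.save()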
    @classmethod
    def tearDownClass(cls):
        se = SearchEngineFactory().create()
        q = Query(se=se)
        for indexname in [TERMS_INDEX, CONCEPTS_INDEX, RESOURCES_INDEX]:
            q.delete(index=indexname, refresh=True)
        super().tearDownClass()

    def test_search_export_no_request(self):
        """Test SearchResultsExporter without search request"""
        with self.assertRaisesMessage(Exception, "Need to pass in a search request"):
            SearchResultsExporter()

    def test_invalid_format(self):
        """Test SearchResultsExporter with invalid format for shapefile export"""
        request = self.factory.get(
            "/search?tiles=true&export=true&format=shp&compact=false"
        )
        request.user = self.user
        with self.assertRaisesMessage(
            Exception, "Results must be compact to export to shapefile"
        ):
            SearchResultsExporter(search_request=request)

    def test_export_to_csv(self):
        """Test exporting search results to CSV (tilecsv) format"""
        request = self.factory.get("/search?tiles=True&export=True&format=tilecsv")
        request.user = self.user
        exporter = SearchResultsExporter(search_request=request)
        result, _ = exporter.export(format="tilecsv", report_link="false")
        self.assertIn(".csv", result[0]["name"])

    def test_write_export_to_zip(self):
        """Test writing CSV export results to a zip file"""
        request = self.factory.get("/search?tiles=True&export=True&format=tilecsv")
        request.user = self.user
        exporter = SearchResultsExporter(search_request=request)
        result, info = exporter.export(format="tilecsv", report_link="false")
        zip_uuid = exporter.write_export_zipfile(result, info, "test")
        self.assertIsNotNone(zip_uuid)

    # def test_export_to_shp(self):
    #     """Test exporting search results to SHP format"""
    #     request = self.factory.get('/search?tiles=True&export=True&format=shp&compact=True')
    #     request.user = self.user
    #     exporter = SearchResultsExporter(search_request=request)
    #     result = exporter.export(format='shp', report_link='false')
    #     self.assertTrue(any('.shp' in file['name'] for file in result))

    # def test_export_to_geojson(self):
    #     """Test exporting search results to GeoJSON format"""
    #     request = self.factory.get('/search?tiles=True&export=True&format=geojson')
    #     request.user = self.user
    #     exporter = SearchResultsExporter(search_request=request)
    #     result, _ = exporter.export(format='geojson', report_link='false')
    #     self.assertEqual(result['type'], 'FeatureCollection')

    # def test_link_append(self):
    #     """Test appending report link to export"""
    #     request = self.factory.get('/search?tiles=True&export=True&format=tilecsv&reportlink=true')
    #     request.user = self.user
    #     exporter = SearchResultsExporter(search_request=request)
    #     result, _ = exporter.export(format='tilecsv', report_link='true')
    #     self.assertIn('Link', result[0]['outputfile'].getvalue())

    def test_export_to_csv_with_system_values(self):
        """Test exporting search results to CSV with system values included"""
        request = self.factory.get(
            "/search?tiles=True&export=True&format=tilecsv&exportsystemvalues=true"
        )
        request.user = self.user
        exporter = SearchResultsExporter(search_request=request)
        result, _ = exporter.export(format="tilecsv", report_link="false")
        self.assertIn(".csv", result[0]["name"])

        csv_content = result[0]["outputfile"].getvalue()
        csv_reader = csv.DictReader(io.StringIO(csv_content))
        cultural_period_column_name = self.search_model_cultural_period_nodename
        for row in csv_reader:
            self.assertTrue(len(row) > 1, f"{len(row)} column(s) in csv row: {row}")
            cultural_period_value = row[cultural_period_column_name]
            self.assertTrue(
                is_valid_uuid(cultural_period_value),
                f"Expected UUID, got {cultural_period_value}",
            )
            break

    def test_export_to_csv_without_system_values(self):
        """Test exporting search results to CSV without system values"""
        request = self.factory.get(
            "/search?tiles=True&export=True&format=tilecsv&exportsystemvalues=false"
        )
        request.user = self.user
        exporter = SearchResultsExporter(search_request=request)
        result, _ = exporter.export(format="tilecsv", report_link="false")
        self.assertIn(".csv", result[0]["name"])

        csv_content = result[0]["outputfile"].getvalue()
        csv_reader = csv.DictReader(io.StringIO(csv_content))
        cultural_period_column_name = self.search_model_cultural_period_nodename
        for row in csv_reader:
            self.assertTrue(len(row) > 1, f"{len(row)} column(s) in csv row: {row}")
            cultural_period_value = row[cultural_period_column_name]
            self.assertFalse(
                is_valid_uuid(cultural_period_value),
                f"Expected non-UUID, got {cultural_period_value}",
            )
            break
    def test_login_via_basic_auth_good(self):
        """Test that valid basic auth credentials authenticate the request"""
        auth_string = "Basic " + b64encode(b"admin:admin").decode("utf-8")
        request = RequestFactory().get(
            reverse("api_export_results"),
            HTTP_AUTHORIZATION=auth_string,
        )
        request.user = User.objects.get(username="anonymous")
        response = SearchExport().get(request)
        self.assertEqual(request.user.username, "admin")
        self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)

    def test_login_via_basic_auth_rate_limited(self):
        """Test that rate-limited requests are not authenticated via basic auth"""
        auth_string = "Basic " + b64encode(b"admin:admin").decode("utf-8")
        request = RequestFactory().get(
            reverse("api_export_results"),
            HTTP_AUTHORIZATION=auth_string,
            # In reality this would be added by django_ratelimit.
            QUERY_STRING="limited=True",
        )
        request.user = User.objects.get(username="anonymous")
        response = SearchExport().get(request)
        self.assertEqual(request.user.username, "anonymous")
        self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)

    def test_login_via_basic_auth_invalid(self):
        """Test that invalid basic auth credentials are rejected"""
        bad_auth_string = "Basic " + b64encode(b"admin:garbage").decode("utf-8")
        request = RequestFactory().get(
            reverse("api_export_results"),
            HTTP_AUTHORIZATION=bad_auth_string,
        )
        request.user = User.objects.get(username="anonymous")
        response = SearchExport().get(request)
        self.assertEqual(request.user.username, "anonymous")
        self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED)

    def test_script_prefix(self):
        """Test that export succeeds when a non-default script prefix is set"""
        prefix = get_script_prefix()
        set_script_prefix("/nginx")
        self.addCleanup(set_script_prefix, prefix)

        request = self.factory.get("/search?tiles=True&export=True&format=tilecsv")
        request.user = self.user
        exporter = SearchResultsExporter(search_request=request)
        exporter.export(format="tilecsv", report_link="false")


def is_valid_uuid(value, version=4):
    """Check if value is a valid UUID."""
    try:
        uuid_obj = uuid.UUID(value, version=version)
        return str(uuid_obj) == value
    except (ValueError, TypeError):
        # Non-string or missing values (e.g. None from csv.DictReader) are not UUIDs.
        return False