diff --git a/medcat-trainer/envs/env b/medcat-trainer/envs/env index 1f2a45c2f..044eaa2ed 100644 --- a/medcat-trainer/envs/env +++ b/medcat-trainer/envs/env @@ -31,7 +31,7 @@ UNIQUE_DOC_NAMES_IN_DATASETS=True MAX_DATASET_SIZE=10000 ### Solr Concept Search Conf ### -CONCEPT_SEARCH_SERVICE_HOST=localhost +CONCEPT_SEARCH_SERVICE_HOST=solr CONCEPT_SEARCH_SERVICE_PORT=8983 ### DB backup dir ### diff --git a/medcat-trainer/webapp/api/api/admin/actions.py b/medcat-trainer/webapp/api/api/admin/actions.py index 1ec12d16f..07541b5d5 100644 --- a/medcat-trainer/webapp/api/api/admin/actions.py +++ b/medcat-trainer/webapp/api/api/admin/actions.py @@ -11,7 +11,7 @@ from rest_framework.exceptions import PermissionDenied from api.models import AnnotatedEntity, MetaAnnotation, EntityRelation, Document, ConceptDB -from api.solr_utils import drop_collection, import_all_concepts +from api.solr_utils import drop_collection, import_all_concepts, solr_collection_name from api.utils import clear_cdb_cnf_addons from medcat.cdb import CDB @@ -381,7 +381,7 @@ def reset_cdb_filters(modeladmin, request, queryset): def import_concepts(modeladmin, request, queryset): for concept_db in queryset: - logger.info(f'Importing concepts for collection {concept_db.name}_id_{concept_db.id}') + logger.info(f'Importing concepts for collection {solr_collection_name(concept_db)}') import_concepts_from_cdb(concept_db.id) diff --git a/medcat-trainer/webapp/api/api/solr_utils.py b/medcat-trainer/webapp/api/api/solr_utils.py index 1ff34a0f6..34b76b1b5 100644 --- a/medcat-trainer/webapp/api/api/solr_utils.py +++ b/medcat-trainer/webapp/api/api/solr_utils.py @@ -26,6 +26,18 @@ } +def solr_collection_name(cdb_model: ConceptDB) -> str: + """ + Solr-safe collection name for a ConceptDB. + + The SOLR API specifically states this so just replace with underscores. + "collection names must consist entirely of periods, underscores, + hyphens, and alphanumerics as well not start with a hyphen" + """ + safe_name = re.sub(r'[^a-zA-Z0-9._-]', '_', cdb_model.name).lstrip('-') + return f'{safe_name}_id_{cdb_model.id}' + + @tracer.start_as_current_span("cache_solr_collection_schema_types", attributes=solr_trace_attributes) def _cache_solr_collection_schema_types(collection): url = f'http://{SOLR_HOST}:{SOLR_PORT}/solr/{collection}/schema' @@ -85,7 +97,7 @@ def search_collection(cdbs: List[int], raw_query: str): uniq_results_map = {} for cdb in cdbs: cdb_model = ConceptDB.objects.get(id=cdb) - collection_name = f'{cdb_model.name}_id_{cdb_model.id}' + collection_name = solr_collection_name(cdb_model) trace.get_current_span().add_event("Searching collection for CDB", attributes={"collection_name": collection_name, "cdb_id": cdb_model.id, "cdb_name": cdb_model.name, "query": raw_query}) if collection_name not in SOLR_INDEX_SCHEMA: @@ -123,7 +135,7 @@ def search_collection(cdbs: List[int], raw_query: str): @tracer.start_as_current_span("import_all_concepts", attributes=solr_trace_attributes) def import_all_concepts(cdb: CDB, cdb_model: ConceptDB): - collection_name = f'{cdb_model.name}_id_{cdb_model.id}' + collection_name = solr_collection_name(cdb_model) base_url = f'http://{SOLR_HOST}:{SOLR_PORT}/solr' trace.get_current_span().add_event("Importing all concepts for CDB", attributes={ @@ -138,7 +150,6 @@ def import_all_concepts(cdb: CDB, cdb_model: ConceptDB): collections = json.loads(resp.text)['collections'] if collection_name in collections: - # delete collection url = f'{base_url}/admin/collections?action=DELETE&name={collection_name}' requests.get(url) @@ -171,7 +182,7 @@ def import_all_concepts(cdb: CDB, cdb_model: ConceptDB): def drop_collection(cdb_model: ConceptDB): - collection_name = f'{cdb_model.name}_id_{cdb_model.id}' + collection_name = solr_collection_name(cdb_model) base_url = f'http://{SOLR_HOST}:{SOLR_PORT}/solr' url = f'{base_url}/admin/collections?action=DELETE&name={collection_name}' resp = requests.get(url) @@ -189,10 +200,10 @@ def ensure_concept_searchable(cui, cdb: CDB, cdb_model: ConceptDB): cdb: the MedCAT CDB where the cui can be found cdb_model: the associated Django model instance for the CDB. """ - collection = f'{cdb_model.name}_id_{cdb_model.id}' base_url = f'http://{SOLR_HOST}:{SOLR_PORT}/solr' url = f'{base_url}/admin/collections?action=LIST' resp = requests.get(url) + collection = solr_collection_name(cdb_model) if resp.status_code == 200: collections = json.loads(resp.text)['collections'] data = [_concept_dct(cui, cdb, cdb.cui2info[cui])] diff --git a/medcat-trainer/webapp/api/api/tests/test_solr_utils.py b/medcat-trainer/webapp/api/api/tests/test_solr_utils.py index dc38789a3..65a95581f 100644 --- a/medcat-trainer/webapp/api/api/tests/test_solr_utils.py +++ b/medcat-trainer/webapp/api/api/tests/test_solr_utils.py @@ -124,6 +124,25 @@ def fake_get(url, *args, **kwargs): self.assertEqual(response.data['results'][0]['cui'], 'C999') +@override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') +class SolrCollectionNameTests(TestCase): + def test_replaces_spaces_and_other_invalid_characters(self): + cdb = ConceptDB(name='test pack with space_CDB', cdb_file='space_cdb.dat') + cdb.save(skip_load=True) + self.assertEqual( + solr_utils.solr_collection_name(cdb), + f'test_pack_with_space_CDB_id_{cdb.id}', + ) + + def test_strips_leading_hyphens(self): + cdb = ConceptDB(name='-leading-hyphen', cdb_file='hyphen_cdb.dat') + cdb.save(skip_load=True) + self.assertEqual( + solr_utils.solr_collection_name(cdb), + f'leading-hyphen_id_{cdb.id}', + ) + + @override_settings(MEDIA_ROOT='/tmp/mct-tests-solr') class HelperFunctionTests(TestCase): def test_process_result_response_deduplicates_by_cui(self):