diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py index 2825ee561..b6f89c26c 100644 --- a/medcat-v2/medcat/cat.py +++ b/medcat-v2/medcat/cat.py @@ -905,7 +905,8 @@ def load_cdb(cls, model_pack_path: str) -> CDB: @classmethod def load_addons( cls, model_pack_path: str, - addon_config_dict: Optional[dict[str, dict]] = None + addon_config_dict: Optional[dict[str, dict]] = None, + addon_types: Optional[list[Type[AddonComponent]]] = None, ) -> list[tuple[str, AddonComponent]]: """Load addons based on a model pack path. @@ -917,6 +918,9 @@ def load_addons( For instance, `{"meta_cat.Subject": {'general': {'device': 'cpu'}}}` would apply to the specific MetaCAT. + addon_type (Optional[list[Type[AddonComponent]]]): + The types of adddons to include. If not specified, all + addons will be loaded. Defaults to None. Returns: List[tuple(str, AddonComponent)]: list of pairs of adddon names the addons. @@ -931,6 +935,25 @@ def load_addons( components_folder, folder_name)) and folder_name.startswith(AddonComponent.NAME_PREFIX) ] + if addon_types is not None: + # filter based on specified addon types + had_before = len(addon_paths_and_names) + expected_folder_names = [ + addon_type.get_folder_name_for_addon_and_name( + addon_type.addon_type, "") + for addon_type in addon_types + ] + addon_paths_and_names = [ + (addon_path, addon) + for addon_path, addon in addon_paths_and_names + if any( + os.path.basename(addon_path).startswith(expected_prefix) + for expected_prefix in expected_folder_names + ) + ] + logger.debug( + "Filtered %d addon paths down to %d from based on %s", + had_before, len(addon_paths_and_names), addon_types) loaded_addons = [ addon for addon_path, addon_name in addon_paths_and_names if isinstance(addon := ( diff --git a/medcat-v2/medcat/components/addons/addons.py b/medcat-v2/medcat/components/addons/addons.py index c173fc19b..ef5103210 100644 --- a/medcat-v2/medcat/components/addons/addons.py +++ b/medcat-v2/medcat/components/addons/addons.py @@ -13,12 +13,10 @@ class AddonComponent(BaseComponent, Protocol): """Base/abstract addon component class.""" NAME_PREFIX: str = "addon_" NAME_SPLITTER: str = "." + # NOTE: need to implement + addon_type: str config: ComponentConfig - @property - def addon_type(self) -> str: - pass - def is_core(self) -> bool: return False diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py index e6a366670..6c4565a52 100644 --- a/medcat-v2/tests/test_cat.py +++ b/medcat-v2/tests/test_cat.py @@ -524,10 +524,28 @@ def test_can_load_saved(self): def test_can_load_meta_cat(self): addons = cat.CAT.load_addons(self.mpp) + self.assert_has_one_meta_cat(addons) + + def assert_has_one_meta_cat(self, addons: list[AddonComponent]): self.assertEqual(len(addons), 1) _, addon = addons[0] self.assertIsInstance(addon, MetaCATAddon) + def test_can_filter_addons_empty(self): + # NONE -> empty + addons = cat.CAT.load_addons(self.mpp, addon_types=[]) + self.assertFalse(addons) + + def test_can_filter_addons_non_existing(self): + from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon + addons = cat.CAT.load_addons(self.mpp, addon_types=[RelCATAddon]) + self.assertFalse(addons) + + def test_can_filter_addons_meta_cat(self): + # only meta cat -> same as regular + addons = cat.CAT.load_addons(self.mpp, addon_types=[MetaCATAddon]) + self.assert_has_one_meta_cat(addons) + def test_can_load_meta_cat_with_addon_cnf(self, seed: int = -41): mc: MetaCATAddon = cat.CAT.load_addons( self.mpp, addon_config_dict={