Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion medcat-v2/medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,8 @@ def load_cdb(cls, model_pack_path: str) -> CDB:
@classmethod
def load_addons(
cls, model_pack_path: str,
addon_config_dict: Optional[dict[str, dict]] = None
addon_config_dict: Optional[dict[str, dict]] = None,
addon_types: Optional[list[Type[AddonComponent]]] = None,
) -> list[tuple[str, AddonComponent]]:
"""Load addons based on a model pack path.

Expand All @@ -917,6 +918,9 @@ def load_addons(
For instance,
`{"meta_cat.Subject": {'general': {'device': 'cpu'}}}`
would apply to the specific MetaCAT.
addon_type (Optional[list[Type[AddonComponent]]]):
The types of adddons to include. If not specified, all
addons will be loaded. Defaults to None.

Returns:
List[tuple(str, AddonComponent)]: list of pairs of adddon names the addons.
Expand All @@ -931,6 +935,25 @@ def load_addons(
components_folder, folder_name))
and folder_name.startswith(AddonComponent.NAME_PREFIX)
]
if addon_types is not None:
# filter based on specified addon types
had_before = len(addon_paths_and_names)
expected_folder_names = [
addon_type.get_folder_name_for_addon_and_name(
addon_type.addon_type, "")
for addon_type in addon_types
]
addon_paths_and_names = [
(addon_path, addon)
for addon_path, addon in addon_paths_and_names
if any(
os.path.basename(addon_path).startswith(expected_prefix)
for expected_prefix in expected_folder_names
)
]
logger.debug(
"Filtered %d addon paths down to %d from based on %s",
had_before, len(addon_paths_and_names), addon_types)
loaded_addons = [
addon for addon_path, addon_name in addon_paths_and_names
if isinstance(addon := (
Expand Down
6 changes: 2 additions & 4 deletions medcat-v2/medcat/components/addons/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ class AddonComponent(BaseComponent, Protocol):
"""Base/abstract addon component class."""
NAME_PREFIX: str = "addon_"
NAME_SPLITTER: str = "."
# NOTE: need to implement
addon_type: str
config: ComponentConfig

@property
def addon_type(self) -> str:
pass

def is_core(self) -> bool:
return False

Expand Down
18 changes: 18 additions & 0 deletions medcat-v2/tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,10 +524,28 @@ def test_can_load_saved(self):

def test_can_load_meta_cat(self):
addons = cat.CAT.load_addons(self.mpp)
self.assert_has_one_meta_cat(addons)

def assert_has_one_meta_cat(self, addons: list[AddonComponent]):
self.assertEqual(len(addons), 1)
_, addon = addons[0]
self.assertIsInstance(addon, MetaCATAddon)

def test_can_filter_addons_empty(self):
# NONE -> empty
addons = cat.CAT.load_addons(self.mpp, addon_types=[])
self.assertFalse(addons)

def test_can_filter_addons_non_existing(self):
from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon
addons = cat.CAT.load_addons(self.mpp, addon_types=[RelCATAddon])
self.assertFalse(addons)

def test_can_filter_addons_meta_cat(self):
# only meta cat -> same as regular
addons = cat.CAT.load_addons(self.mpp, addon_types=[MetaCATAddon])
self.assert_has_one_meta_cat(addons)

def test_can_load_meta_cat_with_addon_cnf(self, seed: int = -41):
mc: MetaCATAddon = cat.CAT.load_addons(
self.mpp, addon_config_dict={
Expand Down