diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py index 6a58b3a0d..b7fba6a9b 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py @@ -70,13 +70,46 @@ class EmbeddingLinking(Linking): """Choose a device for the linking model to be stored. If None then an appropriate GPU device that is available will be chosen""" context_window_size: int = 14 - """Choose the window size to get context vectors.""" + """Choose the window size to get context vectors. In a trained model + if you increase the context window after training then performance will + degrade significantly.""" use_ner_link_candidates: bool = True """Link candidates are provided by some NER steps. This will flag if - you want to trust them or not.""" + you want to trust them or not. A good guideline is if you've trained + on data from the same distribution then this is probably best set to True. + If you have no training data from the same source distribution then it MIGHT + be better set to False.""" + append_to_ner_link_candidates: bool = False + """If `use_ner_link_candidates` is enabled, generate additional + candidates and append them to existing NER candidates instead of only + generating for entities that have none. This will often result in a slight + increase in recall, and precision.""" + use_pre_inference: bool = True + """Whether to use the pre-inference step to filter candidates before + calculating similarities. This can speed up inference by only calculating + similarities for candidates that are likely to be correct based directly on word + matching.""" learning_rate: float = 1e-4 """Learning rate for training the embedding linker. Only used if the embedding linker is trainable.""" weight_decay: float = 0.01 """Weight decay for training the embedding linker. 
Only used if the embedding linker is trainable.""" + multiple_predictions_per_detected_entity: bool = False + """Whether to allow multiple predictions per detected entity. If False, only + the highest scoring candidate will be returned for each entity. If True, all + candidates that exceed the similarity thresholds will be returned. This can be + useful if you want to return multiple CUIs for an entity, but can also lead to + more false positives.""" + pre_inference_top_k_sampling: int = 1 + """When using pre-inference to filter candidates, how many names to then add + their related CUIs as potential candidates. Higher numbers will increase recall + but also increase inference time, and reduce precision. This is influenced by + `short_similarity_threshold`, i.e. pass the top k samples over the threshold + for inference.""" + inference_top_k_sampling: int = 1 + """At the inference step, after calculating similarity scores, how many candidates + to keep for each entity. Higher numbers will increase recall but also increase + inference time, and often reduce precision. This is influenced by + `long_similarity_threshold`, i.e. take the top k samples over the threshold. This + will be ignored if `multiple_predictions_per_detected_entity` is set to False.""" diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py index 60ac8e6fd..53b52a5cb 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py @@ -33,18 +33,21 @@ def __init__( self, cdb: CDB, config: Config, + tokenizer: BaseTokenizer, model_init_kwargs: Optional[dict[str, Any]] = None, ) -> None: """Initializes the embedding linker with a CDB and configuration. Args: cdb (CDB): The concept database to use. config (Config): The base config. 
+ tokenizer (BaseTokenizer): The tokenizer to use. model_init_kwargs (Optional[dict[str, Any]]): Explicit kwargs that override linker defaults. """ super().__init__() self.cdb = cdb self.config = config + self.tokenizer = tokenizer if not isinstance(config.components.linking, EmbeddingLinking): raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking @@ -369,7 +372,7 @@ def _set_filters(self) -> None: def _disambiguate_by_cui( self, cui_candidates: list[str], scores: Tensor - ) -> tuple[str, float]: + ) -> list[tuple[str, float]]: """Disambiguate a detected concept by a list of potential cuis Args: cuis (list[str]): Potential cuis @@ -377,19 +380,71 @@ def _disambiguate_by_cui( scores (Tensor): Scores for the detected cui2info concepts similarity cui_keys (list[str]): idx_to_cui inverse Returns: - tuple[str, float]: - The CUI and its similarity + list[tuple[str, float]]: + The selected CUIs and their similarities. """ cui_idxs = [ self._cui_to_idx[cui] for cui in cui_candidates if cui in self._cui_to_idx ] + if not cui_idxs: + return [] + candidate_scores = scores[cui_idxs] + + if self.cnf_l.multiple_predictions_per_detected_entity: + threshold = self.cnf_l.long_similarity_threshold + selected_mask = candidate_scores >= threshold + selected_positions = torch.nonzero(selected_mask, as_tuple=True)[0] + + return [ + ( + self._cui_keys[cui_idxs[pos]], + float(candidate_scores[pos].item()), + ) + for pos in selected_positions.tolist() + ] + candidate_idx = int(torch.argmax(candidate_scores).item()) best_idx = cui_idxs[candidate_idx] - predicted_cui = self._cui_keys[best_idx] - similarity = float(candidate_scores[candidate_idx].item()) - return predicted_cui, similarity + return [ + ( + self._cui_keys[best_idx], + float(candidate_scores[candidate_idx].item()), + ) + ] + + def _get_predictions_from_names( + self, + selected_name_idxs: list[int], + row_scores: Tensor, + cui_scores_row: Tensor, + 
name_to_cuis: Optional[dict[str, list[str]]] = None, + ) -> list[tuple[str, float]]: + """Retrieve all cuis from the candidate names + + Optional - use name to cuis that has already been generated + with link candidates + """ + cuis_set: set[str] = set() + for name_idx in selected_name_idxs: + selected_name = self._name_keys[name_idx] + if name_to_cuis is None: + cuis_set.update( + self.cdb.name2info[selected_name]["per_cui_status"].keys() + ) + else: + cuis_set.update(name_to_cuis[selected_name]) + cuis = list(cuis_set) + if len(cuis) == 1: + # If there's only one possible cui from the names + # We don't get the similarity for the longest cui score + # Just for speed - this may alter performance if the longest name + # for the cui doesn't meet it's threshold + similarity = max(float(row_scores[name_idx].item()) + for name_idx in selected_name_idxs) + return [(cuis[0], similarity)] + return self._disambiguate_by_cui(cuis, cui_scores_row) def _inference( self, doc: MutableDocument, entities: list[MutableEntity] @@ -409,9 +464,9 @@ def _inference( # score all detected contexts vs all names names_scores = detected_context_vectors @ self.names_context_matrix.T cui_scores = detected_context_vectors @ self.cui_context_matrix.T - sorted_indices = torch.argsort(names_scores, dim=1, descending=True) for i, entity in enumerate(entities): + predictions: list[tuple[str, float]] = [] link_candidates = entity.link_candidates if self.config.components.linking.filter_before_disamb: link_candidates = [ @@ -419,7 +474,13 @@ def _inference( for cui in link_candidates if self.cnf_l.filters.check_filters(cui) ] - if len(link_candidates) == 1: + # TODO: Is this "not" correct? if I skip pre inference I don't care + # about the link candidates? 
+ if ( + len(link_candidates) == 1 and + (self.cnf_l.use_pre_inference or + self.cnf_l.use_ner_link_candidates) + ): best_idx = self._cui_to_idx[link_candidates[0]] predicted_cui = link_candidates[0] if best_idx < 0 or best_idx >= cui_scores.shape[1]: @@ -431,13 +492,19 @@ def _inference( cui_scores.shape[1], ) continue - similarity = cui_scores[i, best_idx].item() - elif len(link_candidates) > 1: + similarity = float(cui_scores[i, best_idx].item()) + predictions = [(predicted_cui, similarity)] + elif ( + len(link_candidates) > 1 and + (self.cnf_l.use_pre_inference or + self.cnf_l.use_ner_link_candidates) + ): + # get all possible names from candidate cuis name_to_cuis = defaultdict(list) for cui in link_candidates: for name in self.cdb.cui2info[cui]["names"]: name_to_cuis[name].append(cui) - + # their position within matricies name_idxs = [ self._name_to_idx[name] for name in name_to_cuis @@ -451,39 +518,83 @@ def _inference( entity.detected_name, ) continue + # get all the scores for the names indexed_scores = names_scores[i, name_idxs] best_local_pos = int(torch.argmax(indexed_scores).item()) best_global_idx = name_idxs[best_local_pos] - similarity = names_scores[i, best_global_idx].item() - best_name = self._name_keys[best_global_idx] - cuis = name_to_cuis[best_name] - if len(cuis) == 1: - predicted_cui = cuis[0] - else: - predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) - else: - row_sorted = sorted_indices[i] # sorted candidate indices for entity i + similarity = float(names_scores[i, best_global_idx].item()) + selected_name_idxs = [ + name_idx + for name_idx in name_idxs + if float(names_scores[i, name_idx].item()) >= + self.cnf_l.long_similarity_threshold + ] + # if no names pass the threshold - no cuis will + # skip this detected entity + if not selected_name_idxs: + continue - # Find the first candidate in this row with CUIs - first_true_pos = int( - torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][ - 0 - ].item() + 
predictions = self._get_predictions_from_names( + selected_name_idxs, + names_scores[i], + cui_scores[i, :], + name_to_cuis, + ) + else: + # if there are no link candidates + # or you don't want to use them + row_scores = names_scores[i] + # get all names that pass the threshold + selected_mask = self._valid_names & ( + row_scores >= self.cnf_l.long_similarity_threshold ) + selected_name_idxs = torch.nonzero( + selected_mask, as_tuple=True + )[0].tolist() + # if none pass the threshold + if not selected_name_idxs: + continue - # Get global index + name - top_name_idx = int(row_sorted[first_true_pos].item()) - similarity = names_scores[i, top_name_idx].item() - detected_name = self._name_keys[top_name_idx] - cuis = list(self.cdb.name2info[detected_name]["per_cui_status"].keys()) + # if there are too many, take the top k to reduce processing time + # this is a trade off between compute time and predictive power + # as k increases, processing time increases + if len(selected_name_idxs) > self.cnf_l.inference_top_k_sampling: + selected_scores = row_scores[selected_name_idxs] + topk_positions = torch.topk( + selected_scores, k=self.cnf_l.inference_top_k_sampling + ).indices.tolist() + selected_name_idxs = [ + selected_name_idxs[pos] for pos in topk_positions + ] + + + predictions = self._get_predictions_from_names( + selected_name_idxs, + row_scores, + cui_scores[i, :], + ) - predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) - if not self.cnf_l.filters.check_filters(predicted_cui): - continue - if self._check_similarity(similarity): - entity.cui = predicted_cui - entity.context_similarity = similarity - yield entity + for predicted_cui, predicted_similarity in predictions: + # check if the predicted cui passes the filters + if not self.cnf_l.filters.check_filters(predicted_cui): + continue + # This check is useful when there's a single link candidate + # Or only a single prediction that's been disambiguated + if not 
self._check_similarity(predicted_similarity): + continue + if self.cnf_l.multiple_predictions_per_detected_entity: + # create a barebones entity that has what is requried + ent = self.tokenizer.create_entity( + doc, + entity.base.start_index, + entity.base.end_index, + entity.detected_name, + ) + else: + ent = entity + ent.cui = predicted_cui + ent.context_similarity = predicted_similarity + yield ent def _check_similarity(self, context_similarity: float) -> bool: if self.cnf_l.long_similarity_threshold: @@ -503,7 +614,10 @@ def _build_context_matrices(self) -> None: ) def _generate_link_candidates( - self, doc: MutableDocument, entities: list[MutableEntity] + self, + doc: MutableDocument, + entities: list[MutableEntity], + append_to_existing: bool = False, ) -> None: """Generate link candidates for each detected entity based on context vectors with size 0. Compare to names to get the most @@ -523,22 +637,55 @@ def _generate_link_candidates( # valid names via filtering and contain at least 1 cui valid_mask = self._valid_names[row_sorted] - if self.cnf_l.short_similarity_threshold > 0: - # thresholded selection + valid_positions = torch.nonzero(valid_mask, as_tuple=True)[0] + + if ( + self.cnf_l.short_similarity_threshold > 0 and + self.cnf_l.pre_inference_top_k_sampling > 0 + ): + # Require candidates to satisfy BOTH criteria: + # (a) score above threshold and (b) within top-k valid names. + valid_scores = row_scores[valid_positions] + k = min(self.cnf_l.pre_inference_top_k_sampling, len(valid_positions)) + if k > 0: + topk_rel = torch.topk(valid_scores, k=k).indices + topk_positions = valid_positions[topk_rel] + keep_mask = ( + row_scores[topk_positions] >= + self.cnf_l.short_similarity_threshold + ) + valid_positions = topk_positions[keep_mask] + else: + valid_positions = valid_positions[:0] + elif self.cnf_l.short_similarity_threshold > 0: + # Threshold-only mode. 
above_thresh_mask = row_scores >= self.cnf_l.short_similarity_threshold selected_mask = valid_mask & above_thresh_mask valid_positions = torch.nonzero(selected_mask, as_tuple=True)[0] + elif self.cnf_l.pre_inference_top_k_sampling > 0: + # Top-k-only mode among valid names. + valid_scores = row_scores[valid_positions] + k = min(self.cnf_l.pre_inference_top_k_sampling, len(valid_positions)) + if k > 0: + topk_rel = torch.topk(valid_scores, k=k).indices + valid_positions = valid_positions[topk_rel] + else: + valid_positions = valid_positions[:0] else: - # just take the single best valid candidate - first_valid = torch.nonzero(valid_mask, as_tuple=True)[0][:1] - valid_positions = first_valid + # If neither criterion is enabled, keep only the best valid candidate. + valid_positions = valid_positions[:1] + # getting cuis from all valid names that pass the threshold and top-k for pos in valid_positions.tolist(): top_name_idx = int(row_sorted[pos].item()) detected_name = self._name_keys[top_name_idx] cuis.update(self.cdb.name2info[detected_name]["per_cui_status"].keys()) - - entity.link_candidates = list(cuis) + + if append_to_existing: + existing = set(entity.link_candidates) + entity.link_candidates = list(existing | cuis) + else: + entity.link_candidates = list(cuis) def _pre_inference( self, doc: MutableDocument @@ -547,9 +694,25 @@ def _pre_inference( avoid full inference step. 
If we want to calculate similarities, or not use link candidates then just return the entities""" all_ents = doc.ner_ents + # if we don't care to use pre inference just return all entities + # as they are + if not self.cnf_l.use_pre_inference: + return [], all_ents + + append_generated_to_ner = ( + self.cnf_l.use_ner_link_candidates + and self.cnf_l.append_to_ner_link_candidates + ) + if not self.cnf_l.use_ner_link_candidates: + # ignoring link candidates generated by NER + to_generate_link_candidates = all_ents + elif append_generated_to_ner: + # Keep NER candidates and append model-generated candidates. to_generate_link_candidates = all_ents else: + # here we only generate link candidates if they don't exist + # i.e. out of vocabulary to_generate_link_candidates = [ entity for entity in all_ents if not entity.link_candidates ] @@ -558,7 +721,11 @@ def _pre_inference( for entities in self._batch_data( to_generate_link_candidates, self.cnf_l.linking_batch_size ): - self._generate_link_candidates(doc, entities) + self._generate_link_candidates( + doc, + entities, + append_generated_to_ner + ) # filter out entities with no link candidates after thresholding filtered_ents = [ent for ent in all_ents if ent.link_candidates] @@ -569,6 +736,10 @@ def _pre_inference( le: list[MutableEntity] = [] to_infer: list[MutableEntity] = [] for entity in all_ents: + # if no candidates just skip it + if not entity.link_candidates: + continue + # TODO: Check if this is right now with multiple entites being possible if len(entity.link_candidates) == 1: # if the include filter exists and the only cui is in it if self.cnf_l.filters.check_filters(entity.link_candidates[0]): @@ -576,8 +747,6 @@ def _pre_inference( entity.context_similarity = 1 le.append(entity) continue - elif self.cnf_l.use_ner_link_candidates and not entity.link_candidates: - continue # it has to be inferred due to filters or number of link candidates to_infer.append(entity) return le, to_infer @@ -604,7 +773,7 @@ def 
predict_entities( for entities in self._batch_data(to_infer, self.cnf_l.linking_batch_size): le.extend(list(self._inference(doc, entities))) - return filter_linked_annotations(doc, le) + return filter_linked_annotations(doc, le, True) @property def names_context_matrix(self): @@ -627,4 +796,4 @@ def create_new_component( vocab: Vocab, model_load_path: Optional[str], ) -> "Linker": - return cls(cdb, cdb.config) + return cls(cdb, cdb.config, tokenizer) diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py index 4d3202f00..cb052e0ac 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py @@ -2,6 +2,7 @@ from medcat_embedding_linker.config import EmbeddingLinking from torch import Tensor from medcat.cdb import CDB +from medcat.components.types import TrainableComponent from medcat.config.config import Config, ComponentConfig from medcat.components.linking.vector_context_model import PerDocumentTokenCache from medcat.tokenizing.tokenizers import BaseTokenizer @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) -class Linker(StaticEmbeddingLinker, AbstractManualSerialisable): +class Linker(StaticEmbeddingLinker, AbstractManualSerialisable, TrainableComponent): """Trainable variant of the embedding linker. This class inherits inference and embedding behavior from Linker and provides method hooks for online/offline training. 
@@ -28,7 +29,10 @@ class Linker(StaticEmbeddingLinker, AbstractManualSerialisable): _MODEL_FOLDER_NAME = "trainable_embedding_model" _MODEL_STATE_FILE_NAME = "model_state.pt" - def __init__(self, cdb: CDB, config: Config) -> None: + def __init__(self, + cdb: CDB, + config: Config, + tokenizer: BaseTokenizer) -> None: if not isinstance(config.components.linking, EmbeddingLinking): raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking @@ -41,6 +45,7 @@ def __init__(self, cdb: CDB, config: Config) -> None: super().__init__( cdb, config, + tokenizer, model_init_kwargs=model_init_kwargs, ) self.training_batch: list[tuple] = [] @@ -407,7 +412,7 @@ def create_new_component( vocab: Vocab, model_load_path: Optional[str], ) -> "Linker": - return cls(cdb, cdb.config) + return cls(cdb, cdb.config, tokenizer) def serialise_to(self, folder_path: str) -> None: os.makedirs(folder_path, exist_ok=True) @@ -424,7 +429,8 @@ def deserialise_from( cls, folder_path: str, **init_kwargs ) -> "Linker": cdb = init_kwargs["cdb"] - linker = cls(cdb, cdb.config) + tokenizer = init_kwargs["tokenizer"] + linker = cls(cdb, cdb.config, tokenizer) model_state_path = os.path.join( folder_path, cls._MODEL_FOLDER_NAME, cls._MODEL_STATE_FILE_NAME diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py index fb08af4a7..fe60f0fa4 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py @@ -3,6 +3,7 @@ from medcat.storage.serialisables import AbstractSerialisable from torch import Tensor, nn from transformers import AutoModel, AutoTokenizer +from medcat_embedding_linker.config import EmbeddingLinking as LinkingConfig from tqdm import tqdm import json import 
logging @@ -23,14 +24,16 @@ class ModelForEmbeddingLinking(nn.Module): def __init__( self, embedding_model_name: str, + cnf_l: LinkingConfig, use_projection_layer: bool = False, - top_n_layers_to_unfreeze: int = -1, + top_n_layers_to_unfreeze: int = 0, device: Optional[Union[str, torch.device]] = None, ) -> None: super().__init__() self.language_model = AutoModel.from_pretrained(embedding_model_name) self.base_model_name = self.language_model.name_or_path + self.cnf_l = cnf_l self.use_projection_layer = use_projection_layer self.top_n_layers_to_unfreeze = top_n_layers_to_unfreeze @@ -86,6 +89,10 @@ def _freeze_all_parameters(self) -> None: param.requires_grad = True def unfreeze_top_n_lm_layers(self, n: int) -> None: + self.cnf_l.top_n_layers_to_unfreeze = n + self.top_n_layers_to_unfreeze = n + # Re-apply from a known baseline so repeated calls are deterministic. + self._freeze_all_parameters() # train all LM layers - each layer requires more data if n == -1: for param in self.language_model.parameters(): @@ -133,6 +140,7 @@ def save_pretrained(self, save_directory: Union[str, Path]) -> None: def from_pretrained( cls, path_or_model_name: Union[str, Path], + cnf_l: LinkingConfig, device: Optional[Union[str, torch.device]] = None, **kwargs, ) -> "ModelForEmbeddingLinking": @@ -147,7 +155,7 @@ def from_pretrained( config = json.load(f) config.update(kwargs) - model = cls(**config) + model = cls(cnf_l=cnf_l, **config) state_dict = torch.load(weights_path, map_location="cpu") model.load_state_dict(state_dict) model.to(target_device) @@ -156,6 +164,7 @@ def from_pretrained( # Hugging Face model id/path. 
model = cls( embedding_model_name=str(path_or_model_name), + cnf_l=cnf_l, device=target_device, **kwargs, ) @@ -208,8 +217,19 @@ def _resolve_model_source(path_or_model_name: Union[str, Path]) -> str: return str(path_or_model_name) def _get_model_init_kwargs(self) -> dict[str, Any]: - """Build kwargs passed to ModelForEmbeddingLinking.from_pretrained.""" - return dict(self._model_init_kwargs) + """Build kwargs passed to ModelForEmbeddingLinking.from_pretrained. + + Keep these in sync with runtime linker config so model swaps preserve + trainability settings (i.e. top-n LM layers to unfreeze). + """ + kwargs = dict(self._model_init_kwargs) + if hasattr(self.cnf_l, "use_projection_layer"): + kwargs["use_projection_layer"] = self.cnf_l.use_projection_layer + if hasattr(self.cnf_l, "top_n_layers_to_unfreeze"): + kwargs["top_n_layers_to_unfreeze"] = ( + self.cnf_l.top_n_layers_to_unfreeze + ) + return kwargs def load_transformers(self, embedding_model_name: Union[str, Path]) -> None: """Load tokenizer/model from local path or Hugging Face model id.""" @@ -224,7 +244,9 @@ def load_transformers(self, embedding_model_name: Union[str, Path]) -> None: self.cnf_l.embedding_model_name = str(embedding_model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_source) self.model = ModelForEmbeddingLinking.from_pretrained( - model_source, **model_init_kwargs + model_source, + cnf_l=self.cnf_l, + **model_init_kwargs, ) self.model.eval() self.device = torch.device( diff --git a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py index 1b9591c10..892a81a42 100644 --- a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py +++ b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py @@ -67,7 +67,8 @@ class NonTrainableEmbeddingLinkerTests(unittest.TestCase): cnf = Config() cnf.components.linking = embedding_linker.EmbeddingLinking() cnf.components.linking.comp_name = 
embedding_linker.Linker.name - linker = embedding_linker.Linker(FakeCDB(cnf), cnf) + vtokenizer = FakeTokenizer() + linker = embedding_linker.Linker(FakeCDB(cnf), cnf, vtokenizer) def test_linker_is_not_trainable(self): self.assertNotIsInstance(self.linker, TrainableComponent) @@ -83,7 +84,8 @@ class TrainableEmbeddingLinkerTests(unittest.TestCase): cnf.components.linking.comp_name = ( trainable_embedding_linker.Linker.name ) - linker = trainable_embedding_linker.Linker(FakeCDB(cnf), cnf) + vtokenizer = FakeTokenizer() + linker = trainable_embedding_linker.Linker(FakeCDB(cnf), cnf, vtokenizer) def test_linker_is_trainable(self): self.assertIsInstance(self.linker, TrainableComponent) diff --git a/medcat-plugins/rawstring-tokenizer/README.md b/medcat-plugins/rawstring-tokenizer/README.md new file mode 100644 index 000000000..6f2779fe3 --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/README.md @@ -0,0 +1,38 @@ +# MedCAT-gliner + +This provides [gliner](https://github.com/urchade/GLiNER) based NER step for MedCAT core library. + +# Usage + +First install from PyPI, e.g: +``` +pip install medcat-gliner +``` +Subsequently, if you have an existing model, you should be able to just change the NER component: +``` +cat = CAT.load_model_pack("path/to/existing/model") +# change component +from medcat_gliner import GLiNERConfig +cat.config.components.ner.comp_name = "gliner_ner" +cat.config.components.ner.custom_cnf = GLiNERConfig() +# recreate pipe with new NER component +cat._recreate_pipe() +# use as needed +``` + +## NER recall comparison (linkable SNOMED entities) + +The following results compare the existing NER (vocab based NER with spell checking) implementation with the gliner implementation when used as the NER component within MedCAT. +Evaluation was performed on the **2023 SNOMED CT Linking Challenge** dataset. + +> **Important caveat** +> This is **not a measure of general NER quality**. 
+> Recall is computed only with respect to annotated, linkable SNOMED CT entities present in the linking dataset. +> Mentions outside the annotation scope are treated as false positives by construction, so precision is not meaningful here. + +| Implementation | True Positives | False Negatives | Recall | Runtime | +| ---------------------- | -------------- | --------------- | ------ | ------- | +| Vocab based NER | 10,545 | 3,917 | 0.729 | ~5m 50s | +| GliNER implementation | 7,971 | 6,491 | 0.551 | ~34m | + +As we can see, for this dataset, GliNER is significantly slower and performs worse than the standard vocab based implementation. This is likely because the vocab based NER step has been configured and tuned to work best within the MedCAT pipeline. It is likely that with additional tuning the GliNER implementation could perform as well as, or better than, the vocab based NER does. diff --git a/medcat-plugins/rawstring-tokenizer/pyproject.toml b/medcat-plugins/rawstring-tokenizer/pyproject.toml new file mode 100644 index 000000000..dd531a802 --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/pyproject.toml @@ -0,0 +1,113 @@ +[project] +name = "medcat-rawstring-tokenizer" + +dynamic = ["version"] + +description = "Rawstring tokenizer for MedCAT" + +readme = "README.md" + +requires-python = ">=3.10" + +license = {text = "Apache-2.0"} + +keywords = ["ML", "NLP", "NER+L"] + +authors = [ + {name = "A. Sutton"}, + {name = "M. Ratas"}, +] + +# This should be your name or the names of the organization who currently +# maintains the project, and a valid email address corresponding to the name +# listed. +maintainers = [ + {name = "CogStack", email = "contact@cogstack.org" } +] + +classifiers = [ + # How mature is this project? 
Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + "Development Status :: 5 - Production/Stable", + + "Intended Audience :: Healthcare Industry", + # "Topic :: Natural Language Processing :: Named Entity Recognition and Linking", + + # Specify the Python versions you support here. In particular, ensure + # that you indicate you support Python 3. These classifiers are *not* + # checked by "pip install". See instead "python_requires" below. + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: OS Independent", +] + +# This field lists other packages that your project depends on to run. +# Any package you put here will be installed by pip when your project is +# installed, so they must be valid existing projects. +# +# For an analysis of this field vs pip's requirements files see: +# https://packaging.python.org/discussions/install-requires-vs-requirements/ +dependencies = [ + "medcat[spacy]>=2.7", +] + +# List additional groups of dependencies here (e.g. development +# dependencies). Users will be able to install these using the "extras" +# syntax, for example: +# +# $ pip install sampleproject[dev] +# +# Similar to `dependencies` above, these must be valid existing +# projects. 
+[project.optional-dependencies] # Optional +dev = [ + "ruff~=0.1.7", + "mypy", + "types-tqdm", + "types-setuptools", + "types-PyYAML", +] + +# entry-points to add onto medcat +[project.entry-points."medcat.plugins"] +medcat_rawstring_tokenizer = "medcat_rawstring_tokenizer" + +[project.urls] +"Homepage" = "https://cogstack.org/" +"Bug Reports" = "https://discourse.cogstack.org/" +"Source" = "https://github.com/CogStack/cogstack-nlp/tree/main/medcat-plugins/rawstring-tokenizer" + +[build-system] +# These are the assumed default build requirements from pip: +# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support +requires = ["setuptools>=43.0.0", "setuptools_scm>=8", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +"medcat_rawstring_tokenizer" = ["py.typed"] + +[tool.setuptools_scm] +# look for .git folder in root of repo +root = "../.." +version_scheme = "post-release" +local_scheme = "no-local-version" +tag_regex = "^medcat-rawstring-tokenizer/v(?P<version>\\d+(?:\\.\\d+)*)(?:[ab]\\d+|rc\\d+)?$" +git_describe_command = "git describe --dirty --tags --long --match 'medcat-rawstring-tokenizer/v*'" + +[tool.ruff.lint] +# 1. 
def create_entity(self, doc: MutableDocument,
                  token_start_index: int, token_end_index: int,
                  label: str) -> MutableEntity:
    """Create a new entity covering a token slice of the document.

    The tokens are selected with ``doc[token_start_index:token_end_index]``,
    so ``token_end_index`` is exclusive when picking tokens.

    Args:
        doc (MutableDocument): The document to use.
        token_start_index (int): The token start index.
        token_end_index (int): The token end index (exclusive for slicing).
        label (str): The detected name for the entity.

    Raises:
        ValueError: If the slice selects no tokens.

    Returns:
        MutableEntity: The resulting entity.
    """
    tokens = doc[token_start_index:token_end_index]
    if not tokens:
        raise ValueError("No tokens in the specified range")
    start_char = tokens[0].char_index
    end_char = tokens[-1].end_char_index
    # The entity text is taken straight from the document so it keeps the
    # original separators between tokens. For documents built by this
    # tokenizer, ``end_char_index`` is an exclusive offset, so the plain
    # slice needs no ``+ 1``.
    text = doc.text[start_char:end_char]
    # The stored token end is pushed forward by one because
    # ``Entity.end_index`` reports the stored value minus one.
    return Entity(text,
                  token_start_index,
                  token_end_index + 1,
                  start_char,
                  end_char,
                  label)
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],
                              doc: MutableDocument) -> MutableEntity:
    """Get an entity for the given tokens, reusing an existing one if possible.

    The document is searched (via ``_get_existing_entity``) for an entity
    that covers the same token span. Only when none is found is a new
    entity built from the tokens.

    Args:
        tokens (list[MutableToken]): List of tokens.
        doc (MutableDocument): The document for these tokens.

    Returns:
        MutableEntity: The existing entity for the span, or a new one.
    """
    existing_ent = self._get_existing_entity(tokens, doc)
    # Compare against None explicitly: entities define ``__len__``, so a
    # plain truthiness test would treat a zero-length entity as "not found".
    if existing_ent is not None:
        return existing_ent
    return self.entity_from_tokens(tokens)
+ """ + return Entity + diff --git a/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokens.py b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokens.py new file mode 100644 index 000000000..28e3a66be --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokens.py @@ -0,0 +1,253 @@ +from typing import Any, Iterator, Optional, Union, cast, overload +from bisect import bisect_right +from medcat.tokenizing.tokens import (BaseToken, MutableToken, + BaseEntity, MutableEntity, + BaseDocument, + UnregisteredDataPathException) + +import unicodedata +import re + + +# keep both hyphens and slashes within words +_WORD_RE = re.compile(r"[^\W_]+(?:[^\W_]+)*", re.UNICODE) +# _WORD_RE = re.compile(r"[^\W_]+(?:[-/][^\W_]+)*", re.UNICODE) + + +def _iter_word_spans( + text: str, + base_char_index: int = 0 + ) -> Iterator[tuple[str, int, int]]: + for match in _WORD_RE.finditer(text): + yield (match.group(0), + base_char_index + match.start(), + base_char_index + match.end()) + +class Token: + def __init__(self, + text: str, + index: int, + char_index: int, + end_char_index: int) -> None: + # --- BaseToken fields --- + self._text = text + self._index = index + self._char_index = char_index + self._end_char_index = end_char_index + # --- MutableToken fields --- + self._norm: str = text.lower() + self._to_skip: bool = False + self._is_punctuation: bool = ( + text != "" and unicodedata.category(text[0]).startswith("P") + ) + + # --- BaseToken --- + @property + def text(self) -> str: return self._text + @property + def lower(self) -> str: return self._text.lower() + @property + def text_versions(self) -> list[str]: return [self._norm, self.lower] + @property + def is_upper(self) -> bool: return self._text.isupper() + @property + def is_stop(self) -> bool: return False # handled by transformers + @property + def is_digit(self) -> bool: return self._text.isdigit() + @property + def char_index(self) -> int: return 
self._char_index + @property + def index(self) -> int: return self._index + @property + def end_char_index(self) -> int: return self._end_char_index + @property + def text_with_ws(self) -> str: return self._text + + # --- MutableToken --- + @property + def base(self) -> BaseToken: return cast(BaseToken, self) + @property + def is_punctuation(self) -> bool: return self._is_punctuation + @is_punctuation.setter + def is_punctuation(self, val: bool) -> None: self._is_punctuation = val + @property + def to_skip(self) -> bool: return self._to_skip + @to_skip.setter + def to_skip(self, val: bool) -> None: self._to_skip = val + @property + def lemma(self) -> str: return self._text # no lemmatization, return text as lemma + @property + def tag(self) -> Optional[str]: return None + @property + def norm(self) -> str: return self._norm + @norm.setter + def norm(self, val: str) -> None: self._norm = val + +class Entity: + _addon_extension_paths: set[str] = set() + + def __init__(self, text: str, start_index: int, end_index: int, + start_char: int, end_char: int, label: str = "") -> None: + # --- BaseEntity fields --- + # Token span is [start_index, end_index]: end is exclusive. + # Character span is [start_char, end_char]: end is exclusive. 
+ self._text = text + self._start_index = start_index + self._end_index = end_index + self._start_char = start_char + self._end_char = end_char + self._label = label + self._addon_data: dict[str, Any] = {} + # --- MutableEntity fields --- + self.cui: str = '' + self.detected_name: str = label + self.link_candidates: list[str] = [] + self.context_similarity: float = 0.0 + self.confidence: float = 0.0 + self.id: int = -1 + + # --- BaseEntity --- + @property + def base(self) -> BaseEntity: return cast(BaseEntity, self) + @property + def text(self) -> str: return self._text + @property + def label(self) -> str: return self._label + @property + def start_index(self) -> int: return self._start_index + # This requires -1 for compatibility + @property + def end_index(self) -> int: return self._end_index - 1 + @property + def start_char_index(self) -> int: return self._start_char + @property + def end_char_index(self) -> int: return self._end_char # exclusive end index + + def __iter__(self) -> Iterator[MutableToken]: + for i, (text, char_index, end_char_index) in enumerate( + _iter_word_spans(self._text, self._start_char)): + yield Token(text, self._start_index + i, char_index, end_char_index) + + def __len__(self) -> int: return max(0, self._end_index - self._start_index) + + # --- addon data --- + def set_addon_data(self, path: str, val: Any) -> None: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + self._addon_data[path] = val + + def has_addon_data(self, path: str) -> bool: + return bool(self._addon_data.get(path)) + + def get_addon_data(self, path: str) -> Any: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + return self._addon_data.get(path) + + def get_available_addon_paths(self) -> list[str]: + return [p for p in self._addon_extension_paths if self.has_addon_data(p)] + + @classmethod + def register_addon_path(cls, path: str, def_val: Any = None, + 
force: bool = True) -> None: + cls._addon_extension_paths.add(path) + + +class Document: + _addon_extension_paths: set[str] = set() + + def __init__(self, text: str) -> None: + self._text = text + self._addon_data: dict[str, Any] = {} + self.ner_ents: list[MutableEntity] = [] + self.linked_ents: list[MutableEntity] = [] + self._char_indices: Optional[list[int]] = None + self._tokens: list[Token] = [ + Token(token_text, token_index, char_index, end_char_index) + for token_index, (token_text, char_index, end_char_index) in + enumerate(_iter_word_spans(text)) + ] + + @property + def base(self) -> BaseDocument: return cast(BaseDocument, self) + + @property + def text(self) -> str: return self._text + + @overload + def __getitem__(self, index: int) -> MutableToken: + pass + + @overload + def __getitem__(self, index: slice) -> list[MutableToken]: + pass + + def __getitem__(self, index: Union[int, slice] + ) -> Union[MutableToken, list[MutableToken]]: + if isinstance(index, int): + if index < 0: + index += len(self._tokens) + if index < 0 or index >= len(self._tokens): + raise IndexError("Document index out of range") + return self._tokens[index] + + start, stop, step = index.indices(len(self._tokens)) + if step != 1: + raise ValueError("Token slices must use step=1") + return self._tokens[start:stop] + + def __iter__(self) -> Iterator[MutableToken]: + return iter(self._tokens) + + def __len__(self) -> int: + return len(self._tokens) + + def isupper(self) -> bool: + return self._text.isupper() + + def _ensure_char_indices(self) -> list[int]: + if self._char_indices is None: + self._char_indices = [tkn.char_index for tkn in self._tokens] + return self._char_indices + + def get_tokens(self, start_index: int, end_index: int + ) -> list[MutableToken]: + # Keep MedCAT compatibility (inclusive end index), then resolve to + # full tokens by overlap so partial subword offsets map to words. 
+ span_start = max(0, start_index) + span_end_exclusive = max(span_start, end_index) + 1 + + token_char_indices = self._ensure_char_indices() + lo = max(0, bisect_right(token_char_indices, span_start) - 1) + hi = min( + len(self._tokens), + bisect_right(token_char_indices, span_end_exclusive - 1) + 1 + ) + + return [ + token for token in self._tokens[lo:hi] + if token.end_char_index > span_start and + token.char_index < span_end_exclusive + ] + + + def set_addon_data(self, path: str, val: Any) -> None: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + self._addon_data[path] = val + + def has_addon_data(self, path: str) -> bool: + return bool(self._addon_data.get(path)) + + def get_addon_data(self, path: str) -> Any: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + return self._addon_data.get(path) + + def get_available_addon_paths(self) -> list[str]: + return [p for p in self._addon_extension_paths if self.has_addon_data(p)] + + @classmethod + def register_addon_path(cls, path: str, def_val: Any = None, + force: bool = True) -> None: + cls._addon_extension_paths.add(path) \ No newline at end of file diff --git a/medcat-plugins/rawstring-tokenizer/tests/__init__.py b/medcat-plugins/rawstring-tokenizer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat-plugins/rawstring-tokenizer/tests/test_rawstring_tokenizer.py b/medcat-plugins/rawstring-tokenizer/tests/test_rawstring_tokenizer.py new file mode 100644 index 000000000..bdfb16bfe --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/tests/test_rawstring_tokenizer.py @@ -0,0 +1,78 @@ +from typing import runtime_checkable +from medcat.tokenizing import tokenizers +from medcat_rawstring_tokenizer.tokenizer import RawstringTokenizer +from medcat.config import Config +from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken +from 
medcat.utils.registry import Registry +from medcat.tokenizing.tokenizers import register_tokenizer + +import unittest + + +class RawstringTokenizerInitTests(unittest.TestCase): + default_provider = 'rawstring_tokenizer' + default_cls = RawstringTokenizer + default_creator = RawstringTokenizer.create_new_tokenizer + # spacy, regex, and now this + exp_num_def_tokenizers = 3 + + @classmethod + def setUpClass(cls): + register_tokenizer('rawstring_tokenizer', RawstringTokenizer.create_new_tokenizer) + cls.cnf = Config() + + def def_creator_name(self) -> str: + return Registry.translate_name(self.default_creator) + + def test_has_default(self): + avail_tokenizers = tokenizers.list_available_tokenizers() + self.assertEqual(len(avail_tokenizers), self.exp_num_def_tokenizers) + name, cls_name = [(t_name, t_cls) for t_name, t_cls in avail_tokenizers + if t_name == self.default_provider][0] + self.assertEqual(name, self.default_provider) + self.assertEqual(cls_name, self.def_creator_name()) + + def test_can_create_def_tokenizer(self): + tokenizer = tokenizers.create_tokenizer( + self.default_provider, self.cnf) + self.assertIsInstance(tokenizer, + runtime_checkable(tokenizers.BaseTokenizer)) + self.assertIsInstance(tokenizer, self.default_cls) + + +class TokenizerTests(unittest.TestCase): + default_provider = 'rawstring_tokenizer' + text = "Some text to tokenize" + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + + def setUp(self) -> None: + self.tokenizer = tokenizers.create_tokenizer( + self.default_provider, self.cnf) + self.doc = self.tokenizer(self.text) + self.doc.ner_ents = self._create_ner_ents(self.doc) + self.doc.linked_ents = self.doc.ner_ents.copy() + + def _create_ner_ents( + self, doc: MutableDocument, + targets: list[str] = ["text",]) -> list[MutableEntity]: + token_start = 1 + token_end = 2 + return [ + self.tokenizer.create_entity( + doc=doc, + token_start_index=token_start, + token_end_index=token_end, + label=target) + for target in targets + ] 
+ + def test_getting_entity_based_on_tokens_gets_same_instance(self): + for ent in self.doc.ner_ents: + with self.subTest(f"Ent: {ent} in doc {self.doc}"): + tokens = list(ent) + got_ent = self.tokenizer.entity_from_tokens_in_doc(tokens, self.doc) + self.assertIs(got_ent, ent) + self.assertIn(got_ent, self.doc.ner_ents) diff --git a/medcat-plugins/transformer-ner/README.md b/medcat-plugins/transformer-ner/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/medcat-plugins/transformer-ner/pyproject.toml b/medcat-plugins/transformer-ner/pyproject.toml new file mode 100644 index 000000000..9c0ab6e1f --- /dev/null +++ b/medcat-plugins/transformer-ner/pyproject.toml @@ -0,0 +1,117 @@ +[project] +name = "medcat-transformer-ner" + +dynamic = ["version"] + +description = "Transformer based NER for MedCAT" + +readme = "README.md" + +requires-python = ">=3.10" + +license = {text = "Apache-2.0"} + +keywords = ["ML", "NLP", "NER+L"] + +authors = [ + {name = "A. Sutton"}, + {name = "T. Searle"}, + {name = "M. Ratas"}, +] + +# This should be your name or the names of the organization who currently +# maintains the project, and a valid email address corresponding to the name +# listed. +maintainers = [ + {name = "CogStack", email = "contact@cogstack.org" } +] + +classifiers = [ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + "Development Status :: 3 - Alpha", + + "Intended Audience :: Healthcare Industry", + # "Topic :: Natural Language Processing :: Named Entity Recognition and Linking", + + # Specify the Python versions you support here. In particular, ensure + # that you indicate you support Python 3. These classifiers are *not* + # checked by "pip install". See instead "python_requires" below. 
+ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: OS Independent", +] + +# This field lists other packages that your project depends on to run. +# Any package you put here will be installed by pip when your project is +# installed, so they must be valid existing projects. +# +# For an analysis of this field vs pip's requirements files see: +# https://packaging.python.org/discussions/install-requires-vs-requirements/ +dependencies = [ + "medcat[spacy]>=2.7", + "transformers>=4.41.0,<5.0", # avoid major bump + "torch>=2.4.0,<3.0", + "tqdm", +] + +# List additional groups of dependencies here (e.g. development +# dependencies). Users will be able to install these using the "extras" +# syntax, for example: +# +# $ pip install sampleproject[dev] +# +# Similar to `dependencies` above, these must be valid existing +# projects. 
+[project.optional-dependencies] # Optional +dev = [ + "ruff~=0.1.7", + "mypy", + "types-tqdm", + "types-setuptools", + "types-PyYAML", +] + +# entry-points to add onto medcat +[project.entry-points."medcat.plugins"] +medcat_transformer_ner = "medcat_transformer_ner" + +[project.urls] +"Homepage" = "https://cogstack.org/" +"Bug Reports" = "https://discourse.cogstack.org/" +"Source" = "https://github.com/CogStack/cogstack-nlp/tree/main/medcat-plugins/transformer-ner" + +[build-system] +# These are the assumed default build requirements from pip: +# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support +requires = ["setuptools>=43.0.0", "setuptools_scm>=8", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +"medcat_transformer_ner" = ["py.typed"] + +[tool.setuptools_scm] +# look for .git folder in root of repo +root = "../.." +version_scheme = "post-release" +local_scheme = "no-local-version" +tag_regex = "^medcat-transformer-ner/v(?P<version>\\d+(?:\\.\\d+)*)(?:[ab]\\d+|rc\\d+)?$" +git_describe_command = "git describe --dirty --tags --long --match 'medcat-transformer-ner/v*'" + +[tool.ruff.lint] +# 1. 
Enable some extra checks for ruff +select = ["E", "F"] +# ignore unused local variables +ignore = ["F841"] diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py new file mode 100644 index 000000000..1f1d2b174 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py @@ -0,0 +1,3 @@ +from .registration import do_registration as __register + +__register() diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py new file mode 100644 index 000000000..d9c7cf2a0 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py @@ -0,0 +1,34 @@ +from typing import Optional, Any +from medcat.config import Ner + +class TransformerNER(Ner): + """The config exclusively used for the transformer NER""" + language_model_name: str = "michiyasunaga/BioLinkBERT-large" + """Name/path of the language model. It must be downloadable from + huggingface or linked from an appropriate file directory. NOTE: + use ner_component.load_transformers to load the model, changing this + does nothing.""" + training_batch_size: int = 32 + """The size of the batch to be used for training.""" + max_token_length: int = 512 + """Max number of tokens to be passed to the language model. + Longer sequences will be chunked""" + overlap_chunking: float = 0.1 + """How much each chunk should overlap with the previous one. + This is important to avoid missing entities that are on the border of two chunks.""" + gpu_device: Optional[Any] = None + """Choose a device for the model to be stored / computed on. If None + then an appropriate GPU device that is available will be chosen""" + require_link_candidates: bool = True + """Generate ent.link_candidates based on detected names. This requires + checking the CDB.name2info, and is required for vocab based linking. 
+ Enabling this will lower recall, and most likely increase precision, + it will also decrease computation time.""" + use_prefix_token: bool = False + """Given a detected span, include one token previous to improve recall. + This helps with low signal words not being detected by the model. This will + increase computation time, and could reduce precision.""" + learning_rate: float = 1e-5 + """The learning rate to be used for training the model""" + weight_decay: float = 0.001 + """The weight decay to be used for training the model""" \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py new file mode 100644 index 000000000..71b42af43 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py @@ -0,0 +1,16 @@ +import logging + +from medcat.components.types import CoreComponentType +from medcat.components.types import lazy_register_core_component + + +logger = logging.getLogger(__name__) + + +def do_registration(): + lazy_register_core_component( + CoreComponentType.ner, + "transformer_ner", + "medcat_transformer_ner.transformer_ner", + "NER.create_new_component", + ) diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py new file mode 100644 index 000000000..b931615d9 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py @@ -0,0 +1,574 @@ +from pathlib import Path +from typing import Any, Optional, Union +from medcat.tokenizing.tokens import MutableDocument, MutableEntity +from medcat.components.types import CoreComponentType, TrainableComponent +from medcat.components.types import AbstractEntityProvidingComponent +from medcat.components.ner.vocab_based_annotator import annotate_name +from medcat.tokenizing.tokenizers import BaseTokenizer 
def load_transformers(self, language_model_name: Union[str, Path]) -> None:
    """Load tokenizer/model from local path or Hugging Face model id.

    Nothing is reloaded when a model and tokenizer are already present and
    both the resolved source and the model init kwargs match what was
    loaded last time. Otherwise the tokenizer, model, device, optimizer
    and scheduler are all (re)created, and the configured
    ``language_model_name`` is updated.

    Args:
        language_model_name (Union[str, Path]): A local path or a Hugging
            Face model id.
    """
    model_source = self._resolve_model_source(language_model_name)
    model_init_kwargs = self._get_model_init_kwargs()

    if (
        not hasattr(self, "model")
        or not hasattr(self, "transformer_tokenizer")
        or model_source != self._loaded_model_source
        or model_init_kwargs != self._loaded_model_init_kwargs
    ):
        self.cnf_ner.language_model_name = str(language_model_name)

        self.transformer_tokenizer = AutoTokenizer.from_pretrained(
            model_source,
            clean_up_tokenization_spaces=False
        )
        self.model = ModelForBinaryNER(
            embedding_model_name=model_source,
            id2label=self.id2label,
            **model_init_kwargs
        )

        self.model.eval()
        self.device = torch.device(
            self.cnf_ner.gpu_device
            or ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model.to(self.device)
        self._loaded_model_source: str = model_source
        self._loaded_model_init_kwargs: dict[str, Any] = model_init_kwargs
        # Take the optimiser hyper-parameters from the NER config rather
        # than hard-coding them. The config defaults (1e-5 and 0.001)
        # match the previously hard-coded values.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.cnf_ner.learning_rate,
            weight_decay=self.cnf_ner.weight_decay)
        self.scheduler = get_constant_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=20,
        )
        logger.debug(
            "Loaded embedding model: %s (resolved source: %s) with kwargs=%s "
            "on device: %s",
            language_model_name,
            model_source,
            model_init_kwargs,
            self.device,
        )
all_labels: list[Tensor] = [] + offset_mappings = [] + chunk_char_starts = [] + while start_idx < n_tokens: + end_idx = min(start_idx + self.max_token_length, n_tokens) + + chunk_offsets = offsets[start_idx:end_idx] + + char_start = chunk_offsets[0][0] + char_end = chunk_offsets[-1][1] + chunk_text = text[char_start:char_end] + + # Rebase entities to chunk + # iff this is a training example + if entities is not None: + chunk_entities = [] + for ent in entities: + ent_start = ent.base.start_char_index + ent_end = ent.base.end_char_index # make end exclusive + + if ent_end > char_start and ent_start < char_end: + chunk_entities.append({ + "start": ent_start - char_start, + "end": ent_end - char_start + }) + + # Tokenize chunk + encoding = self.transformer_tokenizer( + chunk_text, + return_offsets_mapping=True, + truncation=True, + padding="max_length", + max_length=self.max_token_length + ) + + offsets_chunk = encoding["offset_mapping"] + + # Label alignment to relevant chunks + if labels_enabled: + chunk_labels = [ + -100 if (start == end) else self.label2id["O"] + for start, end in offsets_chunk + ] + + + for ent in chunk_entities: + ent_token_indices = [] + for i, (token_start, token_end) in enumerate(offsets_chunk): + if token_start == token_end: + continue + if token_start < ent["end"] and token_end > ent["start"]: + ent_token_indices.append(i) + + if not ent_token_indices: + continue + + if len(ent_token_indices) == 1: + chunk_labels[ent_token_indices[0]] = self.label2id["S-ENT"] + continue + + chunk_labels[ent_token_indices[0]] = self.label2id["B-ENT"] + chunk_labels[ent_token_indices[-1]] = self.label2id["E-ENT"] + for i in ent_token_indices[1:-1]: + chunk_labels[i] = self.label2id["I-ENT"] + + all_labels.append(torch.tensor(chunk_labels, dtype=torch.long)) + + all_input_ids.append(torch.tensor(encoding["input_ids"], + dtype=torch.long)) + all_attention_masks.append(torch.tensor(encoding["attention_mask"], + dtype=torch.long)) + 
def train(self, cui: str,
          entity: MutableEntity,
          doc: MutableDocument,
          negative: bool = False,
          names: Union[list[str], dict] = []) -> None:
    """Run one supervised training step over a whole document.

    This is called once per entity, but the model is trained on the whole
    document at once. The step therefore only runs when ``entity`` is the
    last one in ``doc.ner_ents``; every other call returns without doing
    anything.

    Args:
        cui (str): Not used by this component.
        entity (MutableEntity): The entity currently being trained on.
        doc (MutableDocument): The document holding the text and entities.
        negative (bool): Not used by this component.
        names (Union[list[str], dict]): Not used by this component (the
            default is never mutated).
    """
    # if this is the last entity, we'll train
    # kind of a hacky work around, but it's minimal impact on the CAT trainer
    if entity is not doc.ner_ents[-1]:
        return

    text = doc.base.text
    entities = doc.ner_ents
    input_ids, attention_masks, _, _, labels = (
        self._chunk_and_encode(text, entities)
    )
    self.optimizer.zero_grad()
    self.model.train()

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_masks,
        labels=labels
    )
    loss = outputs.loss

    loss.backward()
    torch.nn.utils.clip_grad_norm_(
        self.model.parameters(),
        1.0
    )
    self.optimizer.step()
    self.scheduler.step()
    # The message needs a placeholder for the loss argument; without one
    # the logging module reports a formatting error when DEBUG is enabled.
    logger.debug("NER training step - loss: %s", loss.item())
+ if label == "O": + if current is not None: + spans.append(current) + current = None + continue + + # This is a bit too general for a binary ENT/ Non Ent + # But it's extendable... maybe! + prefix, ent_type = label.split("-", 1) + + abs_start = chunk_char_start + tok_start + abs_end = chunk_char_start + tok_end + + # B starts a new span + if prefix == "B": + if current is not None: + spans.append(current) + current = { + "start": abs_start, + "end": abs_end, + "label": ent_type + } + + # I continues + elif prefix == "I": + if current is not None and current["label"] == ent_type: + current["end"] = abs_end + else: + # Broken sequence -> treat as a new span. + current = { + "start": abs_start, + "end": abs_end, + "label": ent_type + } + + # E closes + elif prefix == "E": + if current is not None and current["label"] == ent_type: + current["end"] = abs_end + spans.append(current) + current = None + else: + # Broken sequence -> treat standalone E as a single-token span. + spans.append({ + "start": abs_start, + "end": abs_end, + "label": ent_type + }) + + # S is a single token span + elif prefix == "S": + if current is not None: + spans.append(current) + current = None + spans.append({ + "start": abs_start, + "end": abs_end, + "label": ent_type + }) + + if current is not None: + spans.append(current) + + return spans + + def _merge_spans(self, spans, text: str) -> list[dict]: + """Merge spans across chunk boundaries. This is required before creating + entities in the doc, otherwise we might have duplicates for the same + entity that got split across chunks. 
Used in inference only.""" + if not spans: + return [] + + spans = sorted(spans, key=lambda x: (x["start"], x["end"])) + merged = [spans[0]] + + for span in spans[1:]: + last = merged[-1] + gap_text = text[last["end"]:span["start"]] + gap_is_soft_separator = not ( + gap_text.strip() or gap_text.strip() in {"/", "-"} + ) + + if span["label"] == last["label"] and ( + span["start"] <= last["end"] or gap_is_soft_separator + ): + last["end"] = max(last["end"], span["end"]) + else: + merged.append(span) + + return merged + + + # Build segments in two modes: + # 1) keep half separators inside tokens, 2) split on half separators. + def _build_segments(self, + split_chars: set[str], + detected_string: str, + detected_start: int, + detected_end: int) -> list[tuple[int, int]]: + segs = [] + seg_start = None + for idx, ch in enumerate(detected_string): + if ch in split_chars: + if seg_start is not None: + segs.append((detected_start + seg_start, detected_start + idx)) + seg_start = None + elif seg_start is None: + seg_start = idx + if seg_start is not None: + segs.append((detected_start + seg_start, detected_end)) + return segs + + def _char_span_to_token_span( + self, + doc: MutableDocument, + start_char: int, + end_char: int, + ) -> Optional[tuple[int, int]]: + token_start = None + token_end = None + + for token in doc: + if token.end_char_index <= start_char: + continue + if token.char_index >= end_char: + break + + if token_start is None: + token_start = token.index + token_end = token.index + 1 + + if token_start is None or token_end is None: + return None + + return token_start, token_end + + def _span_inference(self, spans: list[dict], + doc: MutableDocument, + text: str) -> list[MutableEntity]: + ner_ents: list[MutableEntity] = [] + seen_token_spans = set() + logger.debug("Num detected spans: %s", len(spans)) + for span in spans: + detected_start = span["start"] + detected_end = span["end"] + detected_string = text[detected_start:detected_end] + if not detected_string: 
def predict_entities(self, doc: MutableDocument,
                     ents: list[MutableEntity] | None = None
                     ) -> list[MutableEntity]:
    """Run the NER model over ``doc`` and return the detected entities.

    The document text is chunked and encoded, the model is run in eval
    mode without gradients, each chunk's predictions are decoded into
    character spans, the spans are merged across chunks and finally
    handed to ``_span_inference`` to build the entities.

    Args:
        doc (MutableDocument):
            The document to annotate; its ``text`` is what gets encoded.
        ents (list[MutableEntity] | None):
            Not used by this implementation. This should be None.

    Returns:
        list[MutableEntity]:
            Whatever ``_span_inference`` produces for the merged spans.
    """
    # Offsets are generated from doc.text so they share its coordinates.
    raw_text = doc.text
    token_ids, masks, chunk_offsets, chunk_starts, _ = (
        self._chunk_and_encode(raw_text)
    )

    self.model.eval()
    with torch.no_grad():
        token_ids = token_ids.to(self.device)
        masks = masks.to(self.device)
        model_out = self.model(input_ids=token_ids, attention_mask=masks)
    per_chunk_predictions = model_out.predictions.cpu().tolist()

    # Decode chunk by chunk, keeping the chunk order of the encoder.
    decoded_spans = [
        span
        for chunk_preds, offsets, start in zip(
            per_chunk_predictions, chunk_offsets, chunk_starts
        )
        for span in self._decode_chunk(chunk_preds, offsets, start)
    ]
    merged_spans = self._merge_spans(decoded_spans, raw_text)

    return self._span_inference(merged_spans, doc, raw_text)
+ ) + + # ner.model = AutoModelForTokenClassification.from_pretrained(model_folder) + ner.model = ModelForBinaryNER.from_pretrained( + model_folder, + device=ner.device, + ) + ner.optimizer = torch.optim.AdamW(ner.model.parameters(), + lr=1e-5, + weight_decay=0.001) + ner.scheduler = get_constant_schedule_with_warmup(ner.optimizer, + num_warmup_steps=20) + ner.model.to(ner.device) + ner.model.eval() + + return ner \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py new file mode 100644 index 000000000..86e71b96b --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py @@ -0,0 +1,285 @@ +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Optional, Union +from torch import Tensor, nn +from torchcrf import CRF +from transformers import AutoModelForTokenClassification +import json +import logging +import torch + +logger = logging.getLogger(__name__) + +class ModelForBinaryNER(nn.Module): + """Wrapper around a Hugging Face transformer for transformer-based NER. + + The architecture is: transformer backbone -> linear classifier -> CRF. + """ + # for mypy checking + label_is_start: Tensor + label_is_end: Tensor + + def __init__( + self, + embedding_model_name: str, + id2label: dict[int, str], + num_labels: int = 5, + top_n_layers_to_unfreeze: int = -1, + device: Optional[Union[str, torch.device]] = None, + aux_loss_weight: float = 0.5, + ) -> None: + super().__init__() + self.num_labels = num_labels + self.aux_loss_weight = aux_loss_weight + self.id2label = id2label + self.language_model = AutoModelForTokenClassification.from_pretrained( + embedding_model_name, + num_labels=self.num_labels, + ) + # Make sure hidden states are available for the auxiliary heads. 
def forward(self, **inputs) -> Any:
    """Run the language model, the boundary heads and the CRF.

    Args:
        **inputs: Keyword arguments forwarded to ``self.language_model``.
            Must contain ``"attention_mask"``. An optional ``"labels"``
            entry is removed before the language model is called and, when
            present, is used to compute the losses; positions labelled
            ``-100`` are excluded from them.

    Returns:
        A ``SimpleNamespace`` with ``loss``, ``crf_loss``, ``start_loss``
        and ``end_loss`` (all ``None`` when no labels were given),
        ``logits`` (the emissions), ``start_logits``, ``end_logits``,
        ``predictions`` (CRF-decoded tag ids as a tensor) and
        ``decoded_sequences`` (the raw decode result).
    """
    labels: Optional[Tensor] = inputs.pop("labels", None)
    attention_mask: Tensor = inputs["attention_mask"]

    outputs = self.language_model(**inputs,
                                  return_dict=True,
                                  output_hidden_states=True)
    # The token-classification logits double as CRF emission scores.
    emissions = outputs.logits
    # the last layer's hidden states for the start/end heads
    hidden_states = outputs.hidden_states[-1]

    # Linear classifiers for boundary heads
    start_logits = self.start_head(hidden_states).squeeze(-1) # [B, T]
    end_logits = self.end_head(hidden_states).squeeze(-1) # [B, T]

    # Every loss term stays None unless labels are supplied below.
    loss = None
    crf_loss = None
    start_loss = None
    end_loss = None

    valid_mask = attention_mask.bool()

    if labels is not None:
        labels = labels.long()

        # CRF can't handle -100 labels, so we mask them out.
        crf_mask = valid_mask & (labels != -100)
        safe_labels = labels.clone()
        safe_labels[safe_labels == -100] = 0

        # CRF requires the first timestep to be valid for every sequence.
        crf_mask[:, 0] = True

        # CRF also requires each sequence to have at least one valid timestep.
        # NOTE(review): column 0 was just forced to True for every row, so
        # this guard looks like it can never fire - confirm before removing.
        no_valid_tokens = ~crf_mask.any(dim=1)
        if no_valid_tokens.any():
            crf_mask[no_valid_tokens, 0] = True

        # Masked-out positions get tag id 0 so they are always a valid index.
        safe_labels[~crf_mask] = 0
        # Negated so it can be minimised; presumably the CRF call returns a
        # log-likelihood - confirm against the torchcrf documentation.
        crf_loss = -self.crf(
            emissions,
            safe_labels,
            mask=crf_mask,
            reduction="token_mean",
        )

        # Auxiliary start/end targets
        start_targets, end_targets = self._build_boundary_targets(labels)

        # Use the standard attention mask, but exclude -100 positions
        aux_mask = valid_mask & (labels != -100)

        start_loss = self._masked_bce_loss(start_logits, start_targets, aux_mask)
        end_loss = self._masked_bce_loss(end_logits, end_targets, aux_mask)

        # Total loss: CRF term plus the weighted sum of both boundary terms.
        loss = crf_loss + self.aux_loss_weight * (start_loss + end_loss)

    # Decoding always uses the plain attention mask, with or without labels.
    decoded_sequences = self.crf.decode(emissions, mask=valid_mask)
    # Pack the decoded sequences into one tensor; each row is filled from
    # position 0 and positions past the sequence length stay 0.
    decoded_tensor = torch.zeros(
        emissions.shape[:2],
        dtype=torch.long,
        device=emissions.device,
    )
    for row_idx, seq in enumerate(decoded_sequences):
        if seq:
            decoded_tensor[row_idx, : len(seq)] = torch.tensor(
                seq,
                dtype=torch.long,
                device=emissions.device,
            )

    return SimpleNamespace(
        loss=loss,
        crf_loss=crf_loss,
        start_loss=start_loss,
        end_loss=end_loss,
        logits=emissions,
        start_logits=start_logits,
        end_logits=end_logits,
        predictions=decoded_tensor,
        decoded_sequences=decoded_sequences,
    )
def unfreeze_top_n_lm_layers(self, n: int) -> None:
    """Make the last ``n`` transformer layers of the language model trainable.

    Only ``requires_grad`` flags are switched on; nothing is frozen here.

    Args:
        n: ``-1`` unfreezes every language-model parameter, ``0`` leaves
            the model untouched, and a positive value unfreezes the last
            ``n`` encoder layers (capped at the number of layers).

    Raises:
        ValueError: If ``n`` is smaller than ``-1``, or if the language
            model exposes neither ``encoder.layer`` nor
            ``transformer.layer`` on its base model.
    """
    # Any other negative value would make ``layers[-n:]`` select
    # ``layers[abs(n):]`` - silently the wrong set of layers.
    if n < -1:
        raise ValueError(
            f"n must be -1 (all layers), 0 (none) or positive, got {n}."
        )

    # train all LM layers - each layer requires more data
    if n == -1:
        for param in self.language_model.parameters():
            param.requires_grad = True
        return

    # keep LM fully frozen - better with less data
    if n == 0:
        return

    base = self.language_model.base_model
    # BERT-likes
    if hasattr(base, "encoder") and hasattr(base.encoder, "layer"):
        layers = base.encoder.layer
    # DistilBERT-likes
    elif hasattr(base, "transformer") and hasattr(base.transformer, "layer"):
        layers = base.transformer.layer
    else:
        raise ValueError("Unsupported LM architecture for layer unfreezing.")

    # Asking for more layers than exist simply unfreezes all of them.
    n = min(n, len(layers))
    for layer in layers[-n:]:
        for param in layer.parameters():
            param.requires_grad = True
+ model = cls( + embedding_model_name=str(path_or_model_name), + device=target_device, + **kwargs, + ) + return model \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/tests/__init__.py b/medcat-plugins/transformer-ner/tests/__init__.py new file mode 100644 index 000000000..b40364e1c --- /dev/null +++ b/medcat-plugins/transformer-ner/tests/__init__.py @@ -0,0 +1,26 @@ +# NOTE: mostly copied from medcat tests +import atexit +import os +import shutil + + +RESOURCES_PATH = os.path.join(os.path.dirname(__file__), "resources") +EXAMPLE_MODEL_PACK_ZIP = os.path.join(RESOURCES_PATH, "mct2_model_pack.zip") +UNPACKED_EXAMPLE_MODEL_PACK_PATH = os.path.join( + RESOURCES_PATH, "mct2_model_pack") + + +# unpack model pack at start so we can access stuff like Vocab +print("Unpacking included test model pack") +shutil.unpack_archive(EXAMPLE_MODEL_PACK_ZIP, UNPACKED_EXAMPLE_MODEL_PACK_PATH) + + +def _del_unpacked_model(): + print( + "Cleaning up! Removing unpacked exmaple model pack:", + UNPACKED_EXAMPLE_MODEL_PACK_PATH, + ) + shutil.rmtree(UNPACKED_EXAMPLE_MODEL_PACK_PATH) + + +atexit.register(_del_unpacked_model) diff --git a/medcat-plugins/transformer-ner/tests/helper.py b/medcat-plugins/transformer-ner/tests/helper.py new file mode 100644 index 000000000..4a70e259d --- /dev/null +++ b/medcat-plugins/transformer-ner/tests/helper.py @@ -0,0 +1,84 @@ +from typing import runtime_checkable, Type, Callable + +from medcat.components import types +from medcat.config.config import Config, ComponentConfig + + +class FakeCDB: + def __init__(self, cnf: Config): + self.config = cnf + self.token_counts = {} + self.cui2info = {} + self.name2info = {} + + def weighted_average_function(self, v: int) -> float: + return v * 0.5 + + +class FVocab: + pass + + +class FTokenizer: + pass + + +class ComponentInitTests: + expected_def_components = 1 + default = "default" + # these need to be specified when overriding + comp_type: types.CoreComponentType + default_cls: 
Type[types.BaseComponent] + default_creator: Callable[..., types.BaseComponent] + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + cls.fcdb = FakeCDB(cls.cnf) + cls.fvocab = FVocab() + cls.vtokenizer = FTokenizer() + cls.comp_cnf: ComponentConfig = getattr(cls.cnf.components, cls.comp_type.name) + if isinstance(cls.default_creator, Type): + cls._def_creator_name_opts = (cls.default_creator.__name__,) + else: + # classmethod + cls._def_creator_name_opts = ( + ".".join( + ( + # etiher class.method_name + cls.default_creator.__self__.__name__, + cls.default_creator.__name__, + ) + ), + # or just method_name + cls.default_creator.__name__, + ) + + def test_has_default(self): + avail_components = types.get_registered_components(self.comp_type) + self.assertEqual(len(avail_components), self.expected_def_components) + name, cls_name = avail_components[0] + # 1 name / cls name + eq_name = [name == self.default for name, _ in avail_components] + eq_cls = [ + cls_name in self._def_creator_name_opts for _, cls_name in avail_components + ] + self.assertEqual(sum(eq_name), 1) + # NOTE: for NER both the default as well as the Dict based NER + # have the came class name, so may be more than 1 + self.assertGreaterEqual(sum(eq_cls), 1) + # needs to have the same class where name is equal + self.assertTrue(eq_cls[eq_name.index(True)]) + + def test_can_create_def_component(self): + component = types.create_core_component( + self.comp_type, + self.default, + self.cnf, + self.vtokenizer, + self.fcdb, + self.fvocab, + None, + ) + self.assertIsInstance(component, runtime_checkable(types.BaseComponent)) + self.assertIsInstance(component, self.default_cls) diff --git a/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip b/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip new file mode 100644 index 000000000..b6bc74e49 Binary files /dev/null and b/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip differ diff --git 
class FakeTokenizer:
    """Minimal tokenizer stand-in for the component tests.

    Calling an instance wraps the given text in a ``FakeDocument``; no
    tokenisation is performed.
    """

    def __call__(self, text: str) -> FakeDocument:
        """Return a ``FakeDocument`` holding ``text`` unchanged."""
        return FakeDocument(text)