Abstract
Language models (LMs) such as BERT and GPT have revolutionized natural language processing (NLP). However, the medical field faces challenges in training LMs due to limited data access and privacy constraints imposed by regulations like the Health Insurance Portability and Accountability Act (HIPAA) and the General Data Protection Regulation (GDPR). Federated learning (FL) offers a decentralized solution that enables collaborative learning while preserving data privacy. In this study, we evaluated FL on 2 biomedical NLP tasks encompassing 8 corpora using 6 LMs. Our results show that: 1) FL models consistently outperformed models trained on individual clients' data and sometimes performed comparably to models trained on pooled data; 2) with a fixed total amount of data, FL models trained with more clients performed worse, but pre-trained transformer-based models exhibited great resilience; and 3) FL models significantly outperformed large language models using zero-/one-shot learning and offered much faster inference.
Introduction
The recent advances in deep learning have sparked the widespread adoption of language models (LMs), including prominent examples such as BERT 1 and GPT2, in the field of natural language processing (NLP). These LMs are trained on massive amounts of public text data, comprising billions of words, and have emerged as the dominant technology for various linguistic tasks, including text classification3,4, text generation5,6, information extraction 7–9, and question answering10,11. The success of LMs can be largely attributed to their ability to leverage large volumes of training data. However, in privacy-sensitive domains like medicine, data are often naturally distributed across institutions, making it difficult to construct large corpora to train LMs. To tackle this challenge, the most common approach thus far has been to fine-tune pre-trained LMs for downstream tasks using limited annotated data12,13. Nevertheless, pre-trained LMs are typically trained on text collected from the general domain, whose patterns diverge from those of biomedical text, a phenomenon known as domain shift. Compared to general text, biomedical texts can be highly specialized, containing domain-specific terminology and abbreviations14. For example, medical records and drug descriptions often include specific terms that may not be present in general language corpora, and these terms often vary among clinical institutions. Moreover, biomedical data lack uniformity and standardization across sources, making it challenging to develop NLP models that can effectively handle different formats and structures. Electronic Health Records (EHRs) from different healthcare institutions, for instance, can have varying templates and coding systems15. Consequently, direct transfer learning from LMs pre-trained on the general domain usually suffers a drop in performance and generalizability when applied to the medical domain, as also demonstrated in the literature16.
Therefore, developing LMs that are specifically designed for the medical domain, using large volumes of domain-specific training data, is essential. Another line of research explores pre-training LMs on biomedical data, e.g., BlueBERT12 and PubMedBERT17. These LMs were either pre-trained on mixed-domain data (first pre-trained on general text and then further pre-trained on biomedical text) or pre-trained directly on domain-specific public medical datasets, e.g., PubMed literature and the Medical Information Mart for Intensive Care (MIMIC-III)18, and have shown improved performance compared to classical methods such as conditional random fields (CRF)19 and recurrent neural networks (RNN) (e.g., long short-term memory (LSTM)20) in many biomedical text mining tasks8,9,12,16,21. Nonetheless, the efficacy of these pre-trained medical LMs relies heavily on the availability of large volumes of task-relevant public data, which may not always be readily accessible. All of the approaches mentioned above follow the classical centralized learning regime, which aggregates data from distributed sites and trains a model in a single environment. However, this approach poses significant challenges in medicine, where data privacy is crucial and data access is restricted by regulation. Thus, in practice, institutions can often only train on their local datasets (single-client training). When the local dataset is small, the resulting model often performs poorly on external datasets (poor generalization). To take advantage of massively distributed data and improve model generalizability, federated learning (FL) was introduced in 2016 22 as a novel learning scheme that enables training in a decentralized environment, and it has since achieved many successes in critical domains with data privacy restrictions23–25.
In an FL training loop, clients jointly train a shared global model by sharing model weights or gradients while keeping their data stored locally. By bringing the model to the data, FL avoids sharing raw data while achieving performance competitive with that of a model trained on pooled data. While a growing body of research shows the great promise of applying FL in general NLP26,27, applications of FL in biomedical NLP remain under-explored. Existing works on FL for biomedical NLP either focus on optimizing a single task28,29 or aim to improve communication efficiency28. The current literature lacks a comprehensive comparison of FL on varied biomedical NLP tasks under real-world perturbations. To close this gap, we conducted an in-depth study of two representative NLP tasks, i.e., named entity recognition (NER) and relation extraction (RE), to evaluate the feasibility of adopting FL (e.g., FedAvg30 and FedProx31) with LMs (e.g., Transformer-based models) in biomedical NLP. Specifically, we studied several FL variants under multiple practical learning scenarios, including varied federation scales, different model architectures, data heterogeneity, and comparison with large language models (LLMs) on multiple benchmark datasets. Our major findings include:
When data were independent and identically distributed (IID), models trained using FL, especially pre-trained BERT-based models, performed comparably to centralized learning and significantly better than single-client learning. Even when data were non-IID, the gap could be narrowed by using alternative FL algorithms.
Larger models exhibited better resistance to changes in FL scale. With a fixed total amount of data, the performance of FL models overall degraded as the number of clients increased. However, the deterioration diminished with larger pre-trained models such as BERT-based models and GPT-2.
FL models significantly outperformed large language models (LLMs), e.g., GPT-3, GPT-4, and PaLM 2, using zero-/one-shot learning, in terms of both prediction accuracy and inference speed.
Results
In this section, we present our main results on FL, focusing on several practical facets, including 1) learning tasks, 2) scalability, 3) data distribution, 4) model architectures and sizes, and 5) comparisons with LLMs.
FedAvg, Single-client, and Centralized learning for NER and RE tasks
Table 1 offers a summary of the performance evaluations for FedAvg, single-client learning, and centralized learning on five NER datasets, while Table 2 presents the results on three RE datasets. Our results on both tasks consistently demonstrate that FedAvg outperformed single-client learning. Notably, in cases involving large data volumes, such as BC4CHEMD and 2018 n2c2, FedAvg managed to attain performance levels on par with centralized learning, especially when combined with BERT-based pre-trained models.
Influence of FL scale on the performance of LMs
In clinical applications, there are two distinct learning paradigms. The first involves a small number of clients, each equipped with substantial data, as is common in collaborations within hospital networks. The second involves many widely distributed clients, each holding limited data, as is often the case for collaborations within clinical facilities or on mobile platforms. We investigated the performance of FL under the two paradigms by varying the number of clients while keeping the total volume of training data fixed. The results, summarized in Fig. 1, reveal a consistent trend: larger models, such as those based on BERT and GPT-2, exhibited great resilience to changes in federation scale. In contrast, the lightweight BiLSTM-CRF model was sensitive to changes in scale, and its performance deteriorated rapidly as the number of participating clients increased.
Performance of FL models with varying numbers of clients
Comparison of FedAvg and FedProx with data heterogeneity
Biomedical texts are often highly specialized because different hospitals follow distinct protocols when generating medical records, resulting in large variations (sublanguage differences). Therefore, FL practitioners should account for such data heterogeneity when implementing FL in healthcare systems. We simulated a realistic non-IID scenario by treating BC2GM and JNLPBA as two clients and jointly performing FL. We considered two FL algorithms, FedAvg and FedProx, both of which are widely deployed in practice. For comparison, we also studied a simulated IID setting by randomly splitting the 2018 n2c2 dataset. As shown in Table 3, the performance of FedProx was sensitive to the choice of the hyper-parameter μ; notably, a smaller μ consistently resulted in better performance. When μ was carefully selected, FedProx outperformed FedAvg when the data were non-IID (lenient F1 score of 0.994 vs. 0.934, and strict F1 score of 0.901 vs. 0.884). However, the difference between the two algorithms was mostly indistinguishable when the data were IID (lenient F1 score of 0.880 vs. 0.879, and strict F1 score of 0.820 vs. 0.818).
Impact of the LM size on the performance of different training schemes
We investigated the impact of model size on the performance of FL. We compared 6 models of varying sizes, the smallest comprising 20M parameters and the largest comprising 334M parameters. We picked the BC2GM dataset for illustration and expect similar trends to hold for other datasets. As shown in Fig. 2, in most cases, larger models (represented by larger circles) exhibited better test performance than their smaller counterparts. For example, BlueBERT consistently outperformed BiLSTM-CRF and GPT-2. Among all the models, BioBERT emerged as the top performer, whereas GPT-2 gave the worst performance.
Comparison of model performance across sizes, measured by the number of trainable parameters, on the BC2GM dataset. The size of each circle indicates the number of model parameters, and the color indicates the learning method. The x-axis represents the mean test F1 score with lenient match.
Comparison between FL and LLM
Given the well-demonstrated performance of large language models (LLMs) on various linguistic tasks, we explored the performance gap between LLMs and smaller LMs trained using FL. Fine-tuning LLMs is uncommon because of the formidable computational costs and long training times. Therefore, we selected two representative methods that enable direct inference from pre-trained LLMs, namely zero-shot and one-shot learning, and compared them with models trained using FL. We followed the experimental protocol outlined in a recent study32 and evaluated all the models on two NER datasets (2018 n2c2 and NCBI-disease) and two RE datasets (2018 n2c2 and GAD). The results, summarized in Table 4, show that FL, whether implemented with a BERT-based model or a GPT-2 model, consistently outperformed GPT-3 and even surpassed GPT-4 and PaLM 2 with both zero-shot and one-shot learning. Beyond the performance gains, small LMs trained with FL also offered substantially faster inference.
Discussion
In this study, we investigated FL for biomedical NLP on two established tasks (NER and RE) across 7 benchmark datasets. We examined 6 LMs with varying parameter sizes (ranging from BiLSTM-CRF with 20M parameters to transformer-based models with up to 334M parameters) and compared their performance under centralized learning, single-client learning, and federated learning. On almost all tasks, federated learning achieved significant improvement over single-client learning and often performed comparably to centralized learning without data sharing, demonstrating that it is an effective approach for privacy-preserving learning with distributed data. The only exception is in Fig. 8, where single-client learning outperformed FedAvg when using BERT and Bio_ClinicalBERT. We believe this is due to the lack of training data: as each client owned only 28 training sentences, the data distribution, although IID, was highly under-represented, making it hard for FedAvg to find a globally optimal solution. Surprisingly, FL achieved reasonably good performance even when the training data were limited (284 training sentences in total across all clients), confirming that transfer learning from either the general text domain (e.g., BERT and GPT-2) or the biomedical text domain (e.g., BlueBERT, BioBERT, Bio_ClinicalBERT) benefits downstream biomedical NLP tasks and that pre-training on medical data often gives a further boost. Another interesting finding is that GPT-2 always gave inferior results compared to BERT-based models. We believe this is because GPT-2 is pre-trained for text generation, using only left-to-right attention for next-word prediction. This unidirectional nature prevents it from using the context to the right of each word, which limits its ability to capture dependencies between words in a sentence.
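The difference between left-to-right and bidirectional attention can be made concrete with attention masks, where entry (i, j) is 1 if position i may attend to position j. The sketch below is illustrative only: it uses word-level tokens (real models use subwords) and plain Python lists rather than either model's actual implementation, and the helper names are hypothetical.

```python
from typing import List

def bidirectional_mask(n: int) -> List[List[int]]:
    """BERT-style encoder: every position may attend to every position."""
    return [[1] * n for _ in range(n)]

def causal_mask(n: int) -> List[List[int]]:
    """GPT-2-style decoder: position i may attend only to positions 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def visible_context(tokens: List[str], mask: List[List[int]], i: int) -> List[str]:
    """Tokens whose information can flow into the representation of tokens[i]."""
    return [t for t, allowed in zip(tokens, mask[i]) if allowed]

tokens = "Mutations of BRCA1 gene are associated with breast cancer".split()
i = tokens.index("BRCA1")
print(visible_context(tokens, bidirectional_mask(len(tokens)), i))  # all 9 tokens
print(visible_context(tokens, causal_mask(len(tokens)), i))         # ['Mutations', 'of', 'BRCA1']
```

Under the causal mask, the representation of "BRCA1" never sees "breast cancer", which is exactly the right-hand context that can help decide its label.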
In the sensitivity analysis of FL to the number of clients, we found a monotonic trend: with a fixed amount of training data, FL with fewer clients tends to perform better. For example, the classical BiLSTM-CRF model (20M), with a fixed amount of total training data, performs better with few clients, but its performance deteriorates as more clients join. This is likely due to the increased learning complexity, as FL models need to learn the inter-correlation of data across clients. Interestingly, the transformer-based models (≥108M parameters), which are over five times larger than BiLSTM-CRF, are more resilient to changes in federation scale, possibly owing to their increased learning capacity.
We analyzed the performance of FedProx in real-world non-IID scenarios and compared it with FedAvg to study the behavior of different FL algorithms under data heterogeneity. Although FedProx achieved slightly better performance than FedAvg when the data were non-IID, it is very sensitive to the hyper-parameter μ, which balances the local objective function and the proximal term. Specifically, when the data were IID and μ was set to a large value (e.g., μ=1), FedProx yielded a 2.4% lower lenient F1 score than FedAvg. When the data were non-IID, this performance gap widened further to 5.4%. It is also noteworthy that when μ is set to 0 and all clients perform an equal number of local updates, FedProx essentially reduces to FedAvg.
We also investigated the impact of model size on the performance of FL. We observed that as model size increased, the performance gap between centralized models and FL models narrowed. Interestingly, BioBERT, which shares the same architecture and is similar in size to BERT and Bio_ClinicalBERT, performs comparably to larger models (such as BlueBERT), highlighting the importance of pre-training for model performance. Overall, model size is indicative of learning capacity, and large models tend to perform better than smaller ones. However, large models require longer training time and more computational resources, resulting in a natural trade-off between accuracy and efficiency.
In comparison with LLMs, FL models were superior in terms of both prediction accuracy and inference speed. We hypothesize that LLMs, although they perform well on general linguistic tasks, cannot easily adapt to specialized tasks given zero or one example as input. To close the gap and make better use of LLMs in the context of biomedical NLP, specialized LLMs pre-trained on medical text data 33 or model fine-tuning 34 are needed.
While our results for FL with LMs are promising, we acknowledge the following limitations of our study: 1) most of our experiments, excluding the non-IID study, were conducted in a simulated environment with synthetic data splits, which may not perfectly align with the distribution patterns of real-world FL data; 2) we mostly focused on horizontal FL and have not extended our study to vertical FL35; and 3) we have not considered FL combined with privacy techniques such as differential privacy36 and homomorphic encryption37. To address these limitations and further advance our understanding of FL for LMs, our future work will focus on real-world implementations of FL and explore practical opportunities and challenges such as vertical FL and FL combined with privacy techniques. We believe our study offers comprehensive insights into the potential of FL for LMs and can serve as a catalyst for future research on developing more effective AI systems that leverage distributed clinical data in real-world scenarios.
Methods
NLP tasks and corpora
We compared FL with alternative training schemes on 8 biomedical NLP datasets with a focus on two NLP tasks: NER (5 corpora) and RE (3 corpora). NER and RE are two established tasks for information extraction in biomedical NLP. Given an input sequence of tokens, the goal of NER is to identify and classify the named entities, such as diseases and genes, present in the sequence. RE is often the follow-up task that aims to discover the relations between pairs of named entities. For example, a gene-disease relation (BRCA1-breast cancer) can be identified in the sentence "Mutations of BRCA1 gene are associated with breast cancer". For RE tasks, we take the entity positions as given and formulate the problem as follows: given a sentence and the spans of two entities, the task is to determine the relationship between the two entities.
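One common way to feed "a sentence plus the spans of two entities" to a BERT-style classifier is to replace the two mentions with placeholder tokens before tokenization. The sketch below assumes character-offset spans and uses the placeholder tokens `@GENE$` and `@DISEASE$`, which resemble the anonymization in BioBERT-style RE corpora. The exact markers and the helper `mask_entities` are assumptions for illustration, not necessarily what this study used.

```python
from typing import Tuple

Span = Tuple[int, int]  # (start, end) character offsets, end exclusive

def mask_entities(sentence: str, head: Span, tail: Span,
                  head_tag: str = "@GENE$", tail_tag: str = "@DISEASE$") -> str:
    """Replace the two entity mentions with placeholder tokens.

    A relation classifier then predicts the label from the masked sentence.
    The two spans must not overlap.
    """
    if not (head[1] <= tail[0] or tail[1] <= head[0]):
        raise ValueError("entity spans overlap")
    # Replace the right-most span first so the left span's offsets stay valid.
    pieces = sorted([(head, head_tag), (tail, tail_tag)],
                    key=lambda p: p[0][0], reverse=True)
    for (start, end), tag in pieces:
        sentence = sentence[:start] + tag + sentence[end:]
    return sentence

s = "Mutations of BRCA1 gene are associated with breast cancer"
gene = (s.index("BRCA1"), s.index("BRCA1") + len("BRCA1"))
disease = (s.index("breast cancer"), len(s))
print(mask_entities(s, gene, disease))
# Mutations of @GENE$ gene are associated with @DISEASE$
```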
All NER corpora follow the same BIO notation to distinguish the beginning (B), inside (I), and outside (O) of entities. We adopted most of the preprocessed corpora from the BioBERT paper8, except for the 2018 n2c2 dataset (both NER and RE). For all datasets, we removed duplicated notes and split the data into training (80%), development (10%), and test (10%) sets. A summary of the datasets can be found in Table 5; more detailed descriptions of each dataset are provided in the supplementary materials.
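The BIO scheme can be illustrated with a short routine that converts entity spans into tags. This is a minimal sketch assuming whitespace tokenization and token-index spans with illustrative type labels ("Gene", "Disease"). It is not the Brat2BIO preprocessing tool released with this study.

```python
from typing import List, Tuple

# (start_token, end_token_exclusive, entity_type)
Entity = Tuple[int, int, str]

def to_bio(tokens: List[str], entities: List[Entity]) -> List[str]:
    """Tag B- on the first token of each entity, I- on the rest, O elsewhere."""
    tags = ["O"] * len(tokens)
    for start, end, etype in entities:
        tags[start] = f"B-{etype}"
        for k in range(start + 1, end):
            tags[k] = f"I-{etype}"
    return tags

tokens = "Mutations of BRCA1 gene are associated with breast cancer".split()
tags = to_bio(tokens, [(2, 3, "Gene"), (7, 9, "Disease")])
for tok, tag in zip(tokens, tags):
    print(f"{tok}\t{tag}")
```

Here "breast" receives `B-Disease` and "cancer" receives `I-Disease`, so a multi-token entity stays recoverable from per-token labels.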
Federated learning algorithms
FL represents a family of algorithms that train models collaboratively in a distributed environment. Consider a scenario with $K$ clients holding distributed data $D = \{D_1, D_2, \dots, D_K\}$, where $D_i \subseteq X_i \times Y_i$, and $X_i$ and $Y_i$ are the input and output spaces of the $i$-th client, respectively. The typical FL aims to solve the optimization problem:
$$\min_{W} F(W) = \sum_{i=1}^{K} p_i F_i(W),$$
where $W$ denotes the weights of the model being learned, $F_i$ is the local objective of the $i$-th client, and $p_i$ is the weight of the $i$-th client such that $p_i > 0$ and $\sum_{i=1}^{K} p_i = 1$. The weights are usually determined by the number of training samples held by each client, e.g., $p_i = n_i / \sum_{j=1}^{K} n_j$, where $n_i$ is the number of training samples of the $i$-th client. For example, $p_i = 1/K$ when all clients hold the same amount of training data.
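As a small worked example of this weighting, the snippet below assumes weights proportional to client sample counts, $p_i = n_i / \sum_j n_j$, and evaluates the global objective as the weighted sum of local losses $F_i(W)$ at one fixed $W$. The sample counts and loss values are hypothetical.

```python
from typing import List

def client_weights(num_samples: List[int]) -> List[float]:
    """p_i = n_i / sum_j n_j, so every p_i > 0 and the weights sum to 1."""
    total = sum(num_samples)
    return [n / total for n in num_samples]

def global_objective(local_losses: List[float], weights: List[float]) -> float:
    """F(W) = sum_i p_i * F_i(W), given each client's local loss at the same W."""
    return sum(p * f for p, f in zip(weights, local_losses))

# Hypothetical federation: three clients holding unequal amounts of data.
n = [500, 300, 200]
p = client_weights(n)            # [0.5, 0.3, 0.2]
F_i = [0.40, 0.60, 1.00]         # hypothetical local losses F_i(W)
print(p, global_objective(F_i, p))  # 0.5*0.4 + 0.3*0.6 + 0.2*1.0, approx. 0.58
```

The client with the most data dominates $F(W)$. With equal sample counts, every weight collapses to $1/K$.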
In FL, there are two types of participants: the server and the clients. The server orchestrates the whole FL process, including signaling the start and end of training, synchronizing the local model updates, and dispatching the updated models. The clients are responsible for fetching models from the server, updating the models using their local data, and sending the updated models back to the server.
The whole process consists of four major steps: 1) the clients use their own data to optimize their local objectives (local updates); 2) the clients upload the updated models or gradients to the server; 3) the server collects the local models and synchronizes the updates (model aggregation); and 4) the server dispatches the aggregated model to the clients. While different FL algorithms may have specialized designs for local updates or model aggregation, they share the same training paradigm.
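The four steps can be sketched as a single simulated training loop. This is a toy illustration, not the FedNLP implementation. Each "model" is a single scalar weight fitted by least squares on hypothetical client data, local updates are plain gradient descent, and aggregation is the sample-size-weighted average used by FedAvg.

```python
from typing import List, Tuple

Data = List[Tuple[float, float]]  # hypothetical (x, y) pairs held by one client

def local_update(w: float, data: Data, lr: float = 0.05, steps: int = 5) -> float:
    """Step 1: a client runs a few gradient steps on its local mean squared error."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

def aggregate(local_ws: List[float], sizes: List[int]) -> float:
    """Step 3: the server averages the local models, weighted by sample counts."""
    total = sum(sizes)
    return sum(w * n / total for w, n in zip(local_ws, sizes))

def federated_training(clients: List[Data], rounds: int = 20) -> float:
    w_global = 0.0
    for _ in range(rounds):
        # Step 4 (dispatch w_global), Step 1 (local updates), Step 2 (upload results).
        local_ws = [local_update(w_global, d) for d in clients]
        # Step 3: model aggregation on the server; raw (x, y) data never move.
        w_global = aggregate(local_ws, [len(d) for d in clients])
    return w_global

# Hypothetical clients whose data all follow y = 3x.
clients = [[(1.0, 3.0), (2.0, 6.0)], [(0.5, 1.5)], [(3.0, 9.0), (1.5, 4.5), (2.5, 7.5)]]
print(round(federated_training(clients), 3))  # converges toward 3.0
```

Only the scalar weights cross the client/server boundary, which mirrors how FL keeps data local while still fitting a shared model.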
We considered two popular FL algorithms: Federated Averaging (FedAvg)30 and its variant FedProx31. FedAvg is the most basic and standard FL algorithm; it uses stochastic gradient descent (SGD) to progressively update the local models. More specifically, each client takes a fixed number of gradient descent steps on its local model using its local training data, and the server aggregates these local models by taking their weighted average as the new global model for the next round. However, in FedAvg, the number of local updates may depend on the size of each client's data; when data sizes vary, the number of local updates can differ substantially across clients. FedProx was introduced to tackle this issue of heterogeneous local updates. By adding a proximal term to the local objective, it suppresses the impact of variable local updates. More specifically, at round $t$, the local updates of the $i$-th client seek a solution that minimizes the following objective:
$$\min_{W} h_i(W; W^t) = F_i(W) + \frac{\mu}{2}\lVert W - W^t \rVert^2,$$
where $W^t$ denotes the weights of the global model at round $t$ and $\mu \ge 0$ controls the strength of the proximal term. A comparison of FedAvg and FedProx can be found in Algorithm 1 and Algorithm 2 in the supplementary materials.
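A minimal sketch of the FedProx local step on a toy scalar least-squares problem, with hypothetical data and hyper-parameters. The proximal term adds $\mu (W - W^t)$ to each local gradient, which pulls the local model back toward the global model. With $\mu = 0$ the update is plain local gradient descent, which is the FedAvg local update.

```python
from typing import List, Tuple

Data = List[Tuple[float, float]]  # hypothetical (x, y) pairs on one client

def fedprox_local_update(w_global: float, data: Data, mu: float,
                         lr: float = 0.05, steps: int = 5) -> float:
    """Approximately minimize h(w) = F_i(w) + (mu / 2) * (w - w_global) ** 2
    by gradient descent, where F_i is the mean squared error on local data."""
    w = w_global
    for _ in range(steps):
        grad_f = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        grad_prox = mu * (w - w_global)  # gradient of the proximal term
        w -= lr * (grad_f + grad_prox)
    return w

data = [(1.0, 3.0), (2.0, 6.0)]  # the local optimum is w = 3
for mu in (0.0, 0.1, 1.0, 10.0):
    print(mu, round(fedprox_local_update(0.0, data, mu), 3))
# A larger mu keeps the local model closer to the global model (w_global = 0.0).
```

This illustrates the sensitivity to μ discussed in the results. A large μ can over-constrain local learning, while μ = 0 gives back FedAvg's local step.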
Study design
As shown in Fig. 2, we explored three learning methods: 1) federated learning, 2) centralized learning, and 3) single-client learning. To simulate conventional learning scenarios, we varied the data scale and conducted the following experiments: pooling all client data to train a single model (centralized learning) and training a separate model on each client's local data (single-client learning).
A comparison of centralized learning, federated learning, and single-client learning.
Models
To better understand the effect of LMs on FL, we chose models with parameter counts ranging from 20M to 334M, including Bidirectional Encoder Representations from Transformers (BERT)1 and Generative Pre-trained Transformer (GPT) models, as well as the classical RNN-based BiLSTM-CRF44. BERT-based models use a Transformer encoder and learn bidirectional representations through two unsupervised pre-training tasks (masked language modeling and next-sentence prediction). BERT variants differ in their pre-training corpora and model size, giving rise to models such as BlueBERT12, BioBERT8, and Bio_ClinicalBERT45. BiLSTM-CRF is the only model in our study that is not built upon Transformers. It uses a bidirectional LSTM backbone designed to handle long-term dependencies, followed by a CRF layer, and was once a popular choice for NER. We selected this model to investigate the effect of federated learning on models with fewer parameters. For LLMs, we selected GPT-3, GPT-4, and PaLM 2 for assessment, as all three are publicly accessible for inference. A summary of the models can be found in Table 6, and detailed model descriptions can be found in the supplementary materials.
Training details
Data Preprocessing
We adapted most of the datasets from the BioBERT paper with reasonable modifications, removing duplicate entries and redoing the data splits; details of the cleaning steps can be found in the supplementary materials. The maximum sequence length was set to 512 tokens, and sentences longer than 512 tokens were truncated.
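The deduplication, 80/10/10 re-split, and 512-token truncation can be sketched as follows. This is an illustrative sketch, not the released preprocessing code. It assumes each example is a (tokens, tags) pair of tuples, keeps the first occurrence of an exact duplicate, shuffles with a hypothetical fixed seed, and truncates at the token level without modeling the subword tokenizer.

```python
import random
from typing import List, Tuple

Example = Tuple[Tuple[str, ...], Tuple[str, ...]]  # (tokens, BIO tags)
MAX_LEN = 512

def deduplicate(examples: List[Example]) -> List[Example]:
    """Drop exact duplicate token sequences, keeping the first occurrence."""
    seen, unique = set(), []
    for tokens, tags in examples:
        if tokens not in seen:
            seen.add(tokens)
            unique.append((tokens, tags))
    return unique

def truncate(example: Example, max_len: int = MAX_LEN) -> Example:
    """Trim tokens and tags together so they stay aligned."""
    tokens, tags = example
    return tokens[:max_len], tags[:max_len]

def split_80_10_10(examples: List[Example], seed: int = 0):
    """Shuffle, then cut into train (80%), dev (10%), and test (remaining) sets."""
    shuffled = examples[:]
    random.Random(seed).shuffle(shuffled)
    n_train = int(0.8 * len(shuffled))
    n_dev = int(0.1 * len(shuffled))
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_dev],
            shuffled[n_train + n_dev:])

# Hypothetical toy corpus: 100 distinct sentences plus 10 duplicates.
corpus = [((f"tok{i}",) * 3, ("O",) * 3) for i in range(100)]
corpus += corpus[:10]
train, dev, test = split_80_10_10([truncate(e) for e in deduplicate(corpus)])
print(len(train), len(dev), len(test))  # 80 10 10
```

Deduplicating before splitting prevents the same sentence from landing in both the training and test sets.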
Federated learning simulation
We considered two learning settings: learning from independent and identically distributed (IID) data and learning from non-IID data. For the first setting, we randomly and uniformly split the data into k folds, one per client. For most of our experiments, k was set to 10, and we also varied k from 2 to 10 to study the impact of federation size. For the second setting, we considered learning from heterogeneous data collected from different sources, which represents the real-world scenario where complex and entangled heterogeneities coexist. We picked BC2GM and JNLPBA as two independent clients; both target the same gene entity recognition task but were collected from different sources.
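The IID setting can be simulated by shuffling the training set and dealing it into k near-equal shards, one per client. Because the total is fixed, each client's share shrinks as k grows. A minimal sketch, with a hypothetical seed, training-set size, and helper name:

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def iid_split(data: Sequence[T], k: int, seed: int = 0) -> List[List[T]]:
    """Randomly partition `data` into k shards whose sizes differ by at most one."""
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)
    # Round-robin dealing of the shuffled indices keeps shard sizes balanced.
    return [[data[j] for j in indices[c::k]] for c in range(k)]

train_set = list(range(1000))  # hypothetical fixed total of 1000 training sentences
for k in (2, 5, 10):
    shards = iid_split(train_set, k)
    print(k, [len(s) for s in shards])
# Shard sizes: 500 each (k=2), 200 each (k=5), 100 each (k=10).
```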
LLMs with zero-/one-shot learning
We followed the experimental protocol of a previous study32. In NER, the zero-shot prompts are designed as:
“Task: the task is to extract disease entities in a sentence”
“Input: the input is a sentence.”
“Output: the output is an HTML that highlights all the disease entities in the sentence. The highlighting should only use HTML tags <span style="background-color: #FFFF00"> and </span> and no other tags.”
For one-shot, we add an example input and its expected output:
“Example:
Input: In summary, inactivation of the murine ATP7B gene produces a form of cirrhotic liver disease that resembles Wilson disease in humans and toxic milk phenotype in the mouse
Output: In summary, inactivation of the murine ATP7B gene produces a form of <span style="background-color: #FFFF00"> cirrhotic liver disease </span> that resembles <span style="background-color: #FFFF00">Wilson disease </span> in humans and toxic milk phenotype in the mouse”
For model evaluation, we randomly selected 200 samples from the test set and reported the prediction performance on the selected samples.
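In practice, the zero-/one-shot protocol involves two pieces: assembling the prompt text above, and mapping the model's highlighted HTML back to entity mentions for scoring. The sketch below assumes the LLM reply arrives as a plain string and omits the LLM call itself. The function names are hypothetical. The regular-expression parser is one simple way to recover highlighted spans, not necessarily the one used in the LLM evaluation repository.

```python
import re
from typing import List, Optional, Tuple

OPEN_TAG = '<span style="background-color: #FFFF00">'
CLOSE_TAG = "</span>"

def build_ner_prompt(sentence: str, example: Optional[Tuple[str, str]] = None) -> str:
    """Zero-shot prompt; pass an (input, output) pair as `example` for one-shot."""
    parts = [
        "Task: the task is to extract disease entities in a sentence",
        "Input: the input is a sentence.",
        "Output: the output is an HTML that highlights all the disease entities "
        f"in the sentence. The highlighting should only use HTML tags {OPEN_TAG} "
        f"and {CLOSE_TAG} and no other tags.",
    ]
    if example is not None:
        ex_in, ex_out = example
        parts.append(f"Example:\nInput: {ex_in}\nOutput: {ex_out}")
    parts.append(f"Input: {sentence}\nOutput:")
    return "\n".join(parts)

def extract_highlighted(html: str) -> List[str]:
    """Return the text of every highlighted span, with surrounding whitespace stripped."""
    pattern = re.escape(OPEN_TAG) + r"(.*?)" + re.escape(CLOSE_TAG)
    return [m.strip() for m in re.findall(pattern, html, flags=re.DOTALL)]

reply = ('inactivation produces <span style="background-color: #FFFF00"> cirrhotic '
         'liver disease </span> that resembles <span style="background-color: '
         '#FFFF00">Wilson disease </span> in humans')
print(extract_highlighted(reply))  # ['cirrhotic liver disease', 'Wilson disease']
```

The non-greedy `(.*?)` ensures that two highlights in the same sentence are returned separately rather than merged into one long match.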
Training Models
We used Adam to optimize our models with an initial learning rate of 0.001 and a momentum (β1) of 0.9. The learning rate was scheduled by linear_scheduler_with_warmup. All experiments were performed on a system equipped with an NVIDIA A100 GPU and an AMD EPYC 7763 64-core processor.
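This sketch assumes the scheduler takes the common linear-warmup-then-linear-decay form, as in Hugging Face's `get_linear_schedule_with_warmup`. Under that form, the learning rate rises linearly from 0 to the initial value over the warmup steps and then decays linearly to 0. Warmup and total step counts are not reported, so the values below are hypothetical.

```python
def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear ramp 0 -> base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

BASE_LR = 1e-3               # initial learning rate reported above
WARMUP, TOTAL = 100, 1000    # hypothetical step counts (not reported)
for s in (0, 50, 100, 550, 1000):
    print(s, linear_warmup_lr(s, BASE_LR, WARMUP, TOTAL))
```

The warmup phase avoids large, noisy updates at the start of fine-tuning, when the pre-trained weights and the freshly initialized task head are still poorly matched.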
Reported evaluation
For NER, we reported the evaluation metrics (e.g., F1 score) at the macro-average level under both strict and lenient match criteria. Under strict match, a predicted entity counts as a true positive only when its boundaries exactly match those of the gold standard; under lenient match, it counts as a true positive when its boundaries overlap with those of a gold-standard entity. For all tasks, we repeated the experiments three times and reported the mean and standard deviation to account for randomness.
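To make the two matching criteria concrete, the sketch below scores predicted entity spans against gold spans under strict (identical boundaries) and lenient (any overlap) matching, then computes F1. Several details are assumptions that the NER_eval tool may handle differently: spans are (start, end-exclusive, type) tuples, a match also requires the same entity type, each gold entity can be matched at most once, and the macro average is taken over entity types.

```python
from typing import List, Set, Tuple

Span = Tuple[int, int, str]  # (start, end_exclusive, entity_type)

def _matches(p: Span, g: Span, lenient: bool) -> bool:
    if p[2] != g[2]:
        return False
    if lenient:
        return p[0] < g[1] and g[0] < p[1]  # boundaries overlap
    return p[0] == g[0] and p[1] == g[1]    # boundaries identical

def f1(pred: List[Span], gold: List[Span], lenient: bool) -> float:
    used: Set[int] = set()  # gold entities already matched
    tp = 0
    for p in pred:
        for gi, g in enumerate(gold):
            if gi not in used and _matches(p, g, lenient):
                used.add(gi)
                tp += 1
                break
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    return 2 * precision * recall / (precision + recall) if tp else 0.0

def macro_f1(pred: List[Span], gold: List[Span], lenient: bool) -> float:
    """Average per-type F1 over all entity types seen in gold or predictions."""
    types = sorted({s[2] for s in gold} | {s[2] for s in pred})
    scores = [f1([p for p in pred if p[2] == t], [g for g in gold if g[2] == t], lenient)
              for t in types]
    return sum(scores) / len(scores) if scores else 0.0

gold = [(7, 9, "Disease"), (2, 3, "Gene")]
pred = [(8, 9, "Disease"), (2, 3, "Gene")]  # disease boundary is off by one token
print(macro_f1(pred, gold, lenient=False), macro_f1(pred, gold, lenient=True))  # 0.5 1.0
```

A partially correct boundary (predicting only "cancer" for "breast cancer") is penalized under strict match but credited under lenient match. This is why lenient scores are typically the higher of the two.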
Data Availability
All the datasets involved in this study are publicly available from the following official websites:
2018 n2c2: https://portal.dbmi.hms.harvard.edu/projects/n2c2-nlp/
BC2GM: https://biocreative.bioinformatics.udel.edu/tasks/
BC4CHEMD: https://biocreative.bioinformatics.udel.edu/resources/biocreative-iv/chemdner-corpus/
JNLPBA: http://www.geniaproject.org/shared-tasks/bionlp-jnlpba-shared-task-2004
NCBI-disease: https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/
EUADR: https://biosemantics.erasmusmc.nl/index.php/resources/euadr-corpus
GAD: https://maayanlab.cloud/Harmonizome/dataset/GAD+Gene-Disease+Associations
Code Availability
Our project code is publicly available on GitHub:
Train and evaluate FL models: https://github.com/PL97/FedNLP
Texts preprocessing: https://github.com/PL97/Brat2BIO
Evaluation: https://github.com/PL97/NER_eval
LLMs evaluations: https://github.com/GaoxiangLuo/LLM-BioMed-NER-ER
Author Contributions
L.P. was responsible for the overall experimental design, FL implementation, and writing of the manuscript. G.L. was responsible for the LLM prompt design, LLM experiment, evaluation, and editing of the manuscript. S.Z. and R.Z. contributed to the data collection and editing of the manuscript. J.C., Z.X., and J.S. contributed to the editing of the manuscript and idea discussion.
Acknowledgments
This work was in part supported by Cisco Research under award number 1085646 PO USA000EP390223. The authors acknowledge the Minnesota Supercomputing Institute (MSI) at the University of Minnesota for providing resources that contributed to the research results reported in this paper.
Footnotes
1 A total of 9 entity types are considered: drug, reason, frequency, ADE, strength, duration, route, form, and dosage. Details about the 2018 n2c2 dataset can be found in the supplementary materials.