使用基于注意力的編碼器-解碼器實(shí)現(xiàn)醫(yī)學(xué)圖像描述
來源:DeepHub IMBA
作者:Santhosh Kurnapally
使用計算機(jī)視覺和自然語言處理來為X 射線的圖像生成文本描述。

什么是圖像描述
圖像描述是生成圖像文本描述的過程。它使用自然語言處理和計算機(jī)視覺來為圖像生成描述的文本字幕。一幅圖像可以有很多個不同的描述,但是只要它正確地描述了圖像,并且描述涵蓋了圖像中的大部分信息就可以說是沒問題的。下面是示例圖片和生成的描述文字。

放射學(xué)中的圖像描述
放射學(xué)也稱為診斷成像,是一系列通過拍攝身體部位的照片或圖像來診斷和治療疾病的測試。雖然有幾種不同的成像檢查,但最常見的包括 X 射線、MRI、超聲波、CT 掃描和 PET 掃描。
放射科醫(yī)生將查看這些成像測試的結(jié)果,找到評估和支持診斷的相關(guān)圖像?;颊咄瓿捎跋駥W(xué)檢查后,放射科醫(yī)生將向臨床醫(yī)生提供他們的解釋報告。典型的放射學(xué)報告包括以下部分:檢查名稱或類型、檢查日期、MeSH(醫(yī)學(xué)主題詞庫)、解釋放射科醫(yī)師詳細(xì)信息、臨床病史等,
借助深度學(xué)習(xí)和自然語言處理,我們可以通過描述 X 射線來減少放射科醫(yī)生的工作量,因此在本案例研究中,我們將從 X 射線中提取結(jié)果,將相同的概念擴(kuò)展到其他部分例如MeSH等,
為什么這個問題很重要?
根據(jù)美國放射學(xué)雜志和 BMJ:英國醫(yī)學(xué)雜志,與特定地區(qū)的人口相比,放射科醫(yī)生很少,特別是在農(nóng)村和較小的社區(qū)環(huán)境中,因此醫(yī)學(xué)圖像解釋和編目存在巨大延遲,從而影響到醫(yī)療診斷,并使患者護(hù)理面臨風(fēng)險。
醫(yī)學(xué)圖像由專業(yè)醫(yī)學(xué)專業(yè)人員(放射科醫(yī)師)閱讀和解釋,他們對每個檢查區(qū)域的發(fā)現(xiàn)通過書面醫(yī)學(xué)報告(放射學(xué)報告)進(jìn)行記錄和傳達(dá)。撰寫醫(yī)療報告的過程通常需要 5-10 分鐘左右。一天之內(nèi),醫(yī)生必須編寫數(shù)以百計的醫(yī)學(xué)報告,這可能會花費(fèi)他們很多時間。如果我們開發(fā)的模型可以在沒有放射科醫(yī)生和編目員的任何干預(yù)的情況下加快醫(yī)學(xué)圖像解釋和編目,這將有效地解決了這些問題。
用深度學(xué)習(xí)來解決這個問題!
圖像和文本句子是序列信息,因此我們將在編碼器-解碼器等設(shè)置中使用像 LSTM 或 GRU 這樣的 RNN(循環(huán)神經(jīng)網(wǎng)絡(luò)),并添加注意力機(jī)制來提高我們的模型性能。當(dāng)然使用Transformers 理論上來說會更好。
如何評價我的模特的表現(xiàn)呢?BLEU: Bilingual Evaluation Understudy
BLEU 是一種用于評估機(jī)器翻譯文本質(zhì)量的算法。BLEU 背后的中心思想是機(jī)器翻譯越接近專業(yè)的人工翻譯越好,它也是最早聲稱與人工質(zhì)量判斷具有高度相關(guān)性的指標(biāo)之一,并且到現(xiàn)在仍然是最受歡迎的指標(biāo)之一。
BLEU 的輸出始終是一個介于 0 和 1 之間的數(shù)字。該值表示候選文本與參考文本的相似程度,接近 1 的值表示更相似。本文使用的 BLEU 是基于n-gram 精度改進(jìn)的,因?yàn)樗褂?n-gram 來比較和評價生成文本的質(zhì)量并給出分?jǐn)?shù),它計算快速簡單并且被廣泛使用。
BLEU 的工作方式很簡單。給定一個句子和一組參考句子的一些候選翻譯,我們使用詞袋方法來查看在翻譯和參考句子中同時出現(xiàn)了多少 BOW。BOW 是一種簡單而高效的方法,可確保機(jī)器翻譯包含參考翻譯也包含的關(guān)鍵短語或單詞。換句話說,BLEU 將候選翻譯與人工生成的帶注釋的參考翻譯進(jìn)行比較,并比較候選句子中有多少命中。BOW 出現(xiàn)次數(shù)越多,翻譯效果就越好。
在了解 BLEU 之前,我們需要了解 Precision、Modified Precision 和 Brevity Penalty。
Precision:

這里 tp 和 fp 分別代表真正例和假正例。我們可以認(rèn)為正例大致對應(yīng)于命中或匹配的概念。換句話說正例是我們可以從給定的候選翻譯中構(gòu)建的單詞 n-gram 包。真正例是出現(xiàn)在候選翻譯和一些參考翻譯中的 n-gram。誤報是只出現(xiàn)在候選翻譯中的那些。
Modified Precision:
如果簡單的基于精度的度量計算會產(chǎn)生很大的問題,比如如果我們有一個候選樣本,It it it it it it it it it it it it it"上面的精度計算會給出1作為輸出,但它給定的候選是非常糟糕的。這是因?yàn)榫_度只涉及檢查是否出現(xiàn)了一個命中,但它不檢查是否重復(fù)。因此需要修改精度,如果這些重復(fù)多次,我們將進(jìn)行裁剪:

Count指的是我們分配給某個n-gram的命中次數(shù)。Mw是指在候選句子中出現(xiàn)n-gram的次數(shù)。Mmax,即該n-gram在任何一個參考句子中出現(xiàn)的最大次數(shù)。
Brevity Penalty:
Brevity Penalty懲罰短的候選翻譯,從而確保只有足夠長的機(jī)器翻譯才能獲得高分。它的目標(biāo)是找到與所的候選翻譯的長度最接近的參考句子的長度。如果該參考句子的長度大于候選句子,就會施加一些懲罰;如果候選句子更長,則不應(yīng)用任何懲罰。處罰的具體公式如下:

BLEU:
集成上面的所有 BLEU的公式如下:

這里的N為指定單詞包的大小,或N -gram,Wn表示修正后的精度pn的權(quán)重。
NLTK包中有BLEU現(xiàn)成的實(shí)現(xiàn),我們可以直接使用
from nltk.translate.bleu_score import sentence_bleu
reference = [['this', 'is', 'small', 'test']]
candidate = ['this', 'is', 'a', 'test']
print('Cumulative 1-gram: %f' % sentence_bleu(reference, candidate, weights=(1, 0, 0, 0)))
print('Cumulative 2-gram: %f' % sentence_bleu(reference, candidate, weights=(0.5, 0.5, 0, 0)))
print('Cumulative 3-gram: %f' % sentence_bleu(reference, candidate, weights=(0.33, 0.33, 0.33, 0)))
print('Cumulative 4-gram: %f' % sentence_bleu(reference, candidate, weights=(0.25, 0.25, 0.25, 0.25)))
獲取和理解和處理數(shù)據(jù)
對于這個本文的研究,我們使用來自印第安納大學(xué)醫(yī)院網(wǎng)絡(luò)的開源數(shù)據(jù)。印第安納大學(xué)-胸部x光片(PNG圖片)
https://academictorrents.com/details/5a3a439df24931f410fac269b87b050203d9467d
圖像數(shù)據(jù)的信息如下:
數(shù)據(jù)大小:1.36 GB,圖像數(shù)量:7470,所有圖片均為png格式,可以直接使用OpenCV處理圖像。所有的圖像都有相同的寬度512像素。但是高度從362 p到873 px不等。

圖像中包含了FRONTAL和LATERAL兩個方向的x光
XML報告數(shù)據(jù)如下:
印第安納大學(xué)-胸部x光片(XML報告):
https://academictorrents.com/details/66450ba52ba3f83fbf82ef9c91f2bde0e845aba9
數(shù)據(jù)大小:20.7 MB,報告總數(shù):3955,我們可以使用xml.etree.ElementTree解析XML報告,Xml包含以下重要數(shù)據(jù),需要從Xml中提取。
1、適應(yīng)癥:該數(shù)據(jù)描述了研究原因和/或適用的臨床信息或診斷的簡單、簡潔的陳述。對適應(yīng)癥的清晰理解也可以闡明研究應(yīng)解決的適當(dāng)臨床問題。例如:結(jié)核病檢測陽性、胸痛等,
2、對比:該數(shù)據(jù)描述了是否將這種新的成像檢查與任何可用的先前檢查進(jìn)行比較。比較通常涉及相同身體部位和檢查類型的檢查。
3、發(fā)現(xiàn):該數(shù)據(jù)列出了放射科醫(yī)生在檢查中身體各個部位的觀察結(jié)果。這記錄了該區(qū)域是否被認(rèn)為是正常、異常或潛在異常。例如心臟大小正常??v隔無異常。肺清凈等,
4、 結(jié)果:該數(shù)據(jù)包含調(diào)查結(jié)果的摘要,并報告他們看到的最重要的調(diào)查結(jié)果以及這些調(diào)查結(jié)果的可能原因。本節(jié)提供了最重要的決策信息。例如無急性病、清肺等,
整合上面的2個信息簡單的可視化如下:

EDA(探索性數(shù)據(jù)分析)
使用XML庫,我們從每個患者XML報告中提取“發(fā)現(xiàn)”、圖像路徑和患者id信息,并與它們形成一個數(shù)據(jù)集。
images = []
patient_ids = []
img_findings = []
for filename in tqdm(os.listdir(os.getcwd()+'/reports/ecgen-radiology')):
if filename.endswith(".xml"):
f = os.path.join(os.getcwd()+'/reports/ecgen-radiology',filename)
tree = ET.parse(f)
root = tree.getroot()
for child in root:
if child.tag == 'uId':
patient = child.attrib['id']
if child.tag == 'MedlineCitation':
for attr in child:
if attr.tag == 'Article':
for i in attr:
if i.tag == 'Abstract':
for name in i:
if name.get('Label') == 'FINDINGS':
findings=name.text
for p_image in root.findall('parentImage'):
patient_ids.append(patient)
images.append(p_image.get('id'))
img_findings.append(findings)
總共有3851名患者:
1張圖像患者:446例
2張圖像患者:3208例
3張圖像患者:181例
4張圖像患者15例
5張圖像患者:1例.

為了捕獲大部分信息,我們將兩個圖像的輸入提供給模型,規(guī)則如下
如果患者有一張與報告相關(guān)的 X 射線圖像,我們將相同的圖像復(fù)制兩次作為 image1 和 image2。
如果患者有兩張與報告相關(guān)的 X 射線圖像,我們將第一張圖像做為 image1,第二張做為 image2。
如果患者有兩個以上的 X 射線與報告相關(guān)聯(lián),我們隨機(jī)選擇 2 個 作為 image1 和 image2。
針對于“發(fā)現(xiàn)列”的數(shù)據(jù)處理

在發(fā)現(xiàn)列中大約有13%的空值。我們將刪除在結(jié)果列中具有空值的行,因?yàn)闆]法用隨機(jī)的結(jié)果填充空值。并將其轉(zhuǎn)換為小寫,刪除垃圾詞
image_findings_dataset['findings'] = image_findings_dataset.loc[:,('findings')].str.lower()
#https://stackoverflow.com/questions/19790188/expanding-english-language-contractions-in-python
def decontracted(row):
# specific
row = str(row)
row = re.sub(r"won\'t", "will not", row)
row = re.sub(r"can\'t", "can not", row)
# general
row = re.sub(r"n\'t", " not", row)
row = re.sub(r"\'re", " are", row)
row = re.sub(r"\'s", " is", row)
row = re.sub(r"\'d", " would", row)
row = re.sub(r"\'ll", " will", row)
row = re.sub(r"\'t", " not", row)
row = re.sub(r"\'ve", " have", row)
row = re.sub(r"\'m", " am", row)
row = re.sub('xxxx','',row) #occurs many times in text may be private information which isn't useful
return str(row)
def preprocessing(row):
row = str(row)
row = re.sub(r'xx*','',row) # Removing XXXX
row = re.sub(r'\d','',row) # Removing numbers
temp = ""
for i in row.split(" "): #Removing 2 letter words
if i!= 'no' or i!='ct':
temp = temp + ' ' + i
temp = re.sub(' {2,}', ' ',temp) #Replacing double space with single space
temp = re.sub(r'\.+', ".", temp) #Replacing double . with single .
temp = temp.lstrip() #Removing space at the beginning
temp = temp.rstrip() #Removing space at the end
return temp
image_findings_dataset['findings']= image_findings_dataset['findings'].apply(preprocessing)
理解統(tǒng)計結(jié)果



我們可以看到像胸腔積液(pleural effusion),氣胸(pneumothora),心臟縱隔輪廓( cardiomediastnal silhouette),yi一般情況下我們認(rèn)為這些詞不是正常詞,但這些是醫(yī)學(xué)領(lǐng)域特有的,并且出現(xiàn)的頻率很大,說明預(yù)處理后看起來很干凈。
數(shù)據(jù)拆分和標(biāo)記

如果仔細(xì)觀察結(jié)果列,可以看到結(jié)果列中的數(shù)據(jù)偏向于非疾病數(shù)據(jù)(數(shù)據(jù)不平衡),并且由于我們的數(shù)據(jù)非常少,大約 3300 條記錄,這根本不足以用于深度學(xué)習(xí)方法,所以這里將嘗試使用重新采樣的方法處理數(shù)據(jù)使數(shù)據(jù)平衡(我們嘗試了多種方法,下面的方法是最好的)
在結(jié)果列中重復(fù)了很多數(shù)據(jù),讓我們采用一種策略來訓(xùn)練更好的模型。
第 1 步:讓我們把數(shù)據(jù)集分成兩部分
1、發(fā)現(xiàn)列出現(xiàn)次數(shù)超過 25 次。
2、發(fā)現(xiàn)列少于或等于 5 次。
第 2 步:用 test_size = 0.1 劃分訓(xùn)練測試集以獲得大于 5 的結(jié)果。
第 3 步:將 20% 樣本大小的訓(xùn)練測試集劃分為小于或等于 5 的結(jié)果。然后添加該樣本測試并使用剩下的進(jìn)行訓(xùn)練
第 4 步:上采樣少數(shù)點(diǎn),下采樣多數(shù)點(diǎn)
通過這樣做,可以減少數(shù)據(jù)集中在發(fā)現(xiàn)方面的不平衡
findings_gt_5 = image_findings_dataset[image_findings_dataset['findings_count']>5]
findings_lte_5 = image_findings_dataset[image_findings_dataset['findings_count']<=5]
train,test = train_test_split(findings_gt_5,stratify = findings_gt_5['findings'].values,test_size = 0.1,random_state = 420)
test_findings_lte_5_sample = findings_lte_5.sample(int(0.2*findings_lte_5.shape[0]),random_state = 420)
findings_lte_5 = findings_lte_5.drop(test_findings_lte_5_sample.index,axis=0)
test = test.append(test_findings_lte_5_sample)
test = test.reset_index(drop=True)
train = train.append(findings_lte_5)
train = train.reset_index(drop=True)
train.shape[0],test.shape[0]
image_findings_dataset_majority = train[train['findings_count']>=25] #having value counts >=25
image_findings_dataset_minority = train[train['findings_count']<=5] #having value counts <=5
image_findings_dataset_other = train[(train['findings_count']>5)&(train['findings_count']<25)] #value counts between 5 and 25
n1 = image_findings_dataset_minority.shape[0]
n2 = image_findings_dataset_majority.shape[0]
n3 = image_findings_dataset_other.shape[0]
image_findings_dataset_minority_upsampled = resample(image_findings_dataset_minority,
replace = True,
n_samples = 4*n1,
random_state = 420)
image_findings_dataset_majority_downsampled = resample(image_findings_dataset_majority,
replace = False,
n_samples = n2//5,
random_state = 420)
image_findings_dataset_other_downsampled = resample(image_findings_dataset_other,
replace = False,
n_samples = n3//3,
random_state = 420)
train = pd.concat([image_findings_dataset_majority_downsampled ,image_findings_dataset_minority_upsampled,image_findings_dataset_other_downsampled])
train = train.reset_index(drop=True)
train.shape
在分別對少數(shù)和多數(shù)樣本進(jìn)行上采樣和下采樣后,得到的訓(xùn)練數(shù)據(jù)有 8795 條記錄,測試數(shù)據(jù)有 604 條記錄,在醫(yī)學(xué)這種不平衡數(shù)據(jù)中這個過程是必須的的。
創(chuàng)建令牌標(biāo)記
tokenizer = Tokenizer(filters = '',oov_token = '<unk>') #setting filters to none
tokenizer.fit_on_texts(train.findings_total.values)
train_captions = tokenizer.texts_to_sequences(train.findings_total)
test_captions = tokenizer.texts_to_sequences(test.findings_total)
vocab_size = len(tokenizer.word_index)
caption_len = np.array([len(i) for i in train_captions])
start_index = tokenizer.word_index['<start>'] #tokened value of <start>
end_index = tokenizer.word_index['<end>'] #tokened value of <end>
現(xiàn)在數(shù)據(jù)集已準(zhǔn)備好進(jìn)行建模了

構(gòu)建圖像描述模型
在建立模型之前,讓我們先了解一些注意力在基于的編碼器-解碼器模型中使用的概念。
ChexNet
ChexNet 是一種深度學(xué)習(xí)算法,可以從胸部 X 光圖像中檢測和定位 14 種疾病。在 ChestX-ray14 數(shù)據(jù)集上訓(xùn)練了一個 121 層的卷積神經(jīng)網(wǎng)絡(luò),該數(shù)據(jù)集包含來自 30,805 名獨(dú)特患者的 112,120 張正面視圖 X 射線圖像。結(jié)果非常好超過了執(zhí)業(yè)放射科醫(yī)生的表現(xiàn)。
我們使用 ChexNet 預(yù)訓(xùn)練的權(quán)重來使用遷移學(xué)習(xí)獲得 X 射線的嵌入。由于 ChexNet 權(quán)重在 ChestX-ray14 數(shù)據(jù)集上的疾病分類等任務(wù)中得到了很好的收斂。
論文:https://arxiv.org/pdf/1711.05225v3.pdf
權(quán)重文件:https://www.kaggle.com/datasets/theewok/chexnet-keras-weights
ChexNet 使用與主干類似的架構(gòu)是 DenseNet121,下面是 DenseNet 架構(gòu)。

GloVe
GloVe 是一種用于獲取單詞向量表示的無監(jiān)督學(xué)習(xí)算法。對來自語料庫的聚合全局詞-詞共現(xiàn)統(tǒng)計進(jìn)行訓(xùn)練,得到的表示展示了詞向量空間的線性子結(jié)構(gòu)。
GloVe 本質(zhì)上是一個具有加權(quán)最小二乘目標(biāo)的對數(shù)雙線性模型。該模型的主要理論是簡單的觀察,即單詞-單詞共現(xiàn)概率的比率有可能編碼某種形式的含義。
我們使用預(yù)訓(xùn)練的詞向量將詞轉(zhuǎn)換為嵌入,GloVe 提供多維重新訓(xùn)練的詞向量,其中我們使用 300 維的詞向量進(jìn)行詞嵌入轉(zhuǎn)換。
資料來源:https://nlp.stanford.edu/projects/glove/
Glove300d.zip:https://nlp.stanford.edu/data/glove.6B.zip
LSTM
簡單的RNN不能很好地處理長期依賴關(guān)系。lstm被明確設(shè)計為避免長期依賴問題。

lstm有三個輸入和兩個輸出,能夠向單元狀態(tài)中刪除或添加信息,也可以不加修改地傳遞信息。
注意力機(jī)制
注意力模型也稱為注意力機(jī)制,是一種深度學(xué)習(xí)技術(shù),用于提供對特定組件的額外關(guān)注。注意力模型的目的是將更大、更復(fù)雜的任務(wù)簡化為更小、更易于管理的注意力區(qū)域,以便按順序理解和處理。
注意力模型的最初目的是幫助改善計算機(jī)視覺和基于編碼器-解碼器的神經(jīng)機(jī)器翻譯系統(tǒng)。該系統(tǒng)使用自然語言處理 (NLP) 并依賴于具有復(fù)雜功能的龐大數(shù)據(jù)庫。使用注意力模型有助于創(chuàng)建固定長度向量的映射以生成翻譯和理解。
注意力模型可以簡單的分為3類:
自注意力模型
全局注意力模型
局部注意力模型
本文中我們將使用 Bahdanau 和 Loung 建議的論文中使用全局注意力模型(Global Attention Model)。
該模型基于與源位置和先前生成的目標(biāo)詞相關(guān)聯(lián)的上下文向量來預(yù)測目標(biāo)詞。具有注意機(jī)制的Seq2Seq模型由編碼器、解碼器和注意層組成。

模型編碼實(shí)現(xiàn)
通過加載和下載的權(quán)重來實(shí)現(xiàn)ChexNet,為了進(jìn)行微調(diào)將ChexNet模型的可訓(xùn)練參數(shù)設(shè)置為false,因?yàn)槲覀兿M看味际褂孟嗤臋?quán)重,并且不想在反向傳播中更新這些權(quán)重。
def create_chexnet(chexnet_weights = chexnet_weights,input_size = input_size):
"""
chexnet_weights: weights value in .h5 format of chexnet
creates a chexnet model with preloaded weights present in chexnet_weights file
"""
model = tf.keras.applications.DenseNet121(include_top=False,input_shape = input_size+(3,)) #importing densenet the last layer will be a relu activation layer
#we need to load the weights so setting the architecture of the model as same as the one of the chexnet
x = model.output #output from chexnet
x = GlobalAveragePooling2D()(x)
x = Dense(14, activation="sigmoid", name="chexnet_output")(x) #here activation is sigmoid as seen in research paper
chexnet = tf.keras.Model(inputs = model.input,outputs = x)
chexnet.load_weights(chexnet_weights)
chexnet = tf.keras.Model(inputs = model.input,outputs = chexnet.layers[-3].output) #we will be taking the 3rd last layer (here it is layer before global avgpooling)
#since we are using attention here
return chexnet
下載并使用了300維預(yù)訓(xùn)練的GloVe向量。
glove = {}
with open('/content/drive/MyDrive/Project_on_Drive/glove/glove.6B.300d.txt',encoding='utf-8') as f: #taking 300 dimesions
for line in tqdm(f):
word = line.split() #it is stored as string like this "'the': '.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.4"
glove[word[0]] = np.asarray(word[1:], dtype='float32')
embedding_dim = 300
# create a weight matrix for words in training docs for embedding purpose
embedding_matrix = np.zeros((vocab_size+1, embedding_dim)) #https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
for word, i in tqdm(tokenizer.word_index.items()):
embedding_vector = glove.get(word)
if embedding_vector is not None: #if the word is found in glove vectors
embedding_matrix[i] = embedding_vector[:embedding_dim]
創(chuàng)建數(shù)據(jù)處理管道它將在圖像和文本上執(zhí)行任務(wù),并使它們準(zhǔn)備好被模型使用。
將圖像調(diào)整為255x255像素。
將文本結(jié)果向量化,并將所有結(jié)果填充到相同的長度。
這里使用的圖像增強(qiáng)技術(shù)是在水平方向和垂直方向以均勻概率翻轉(zhuǎn)圖像。如果概率小于33%,則水平翻轉(zhuǎn),如果介于33和66%之間,則垂直翻轉(zhuǎn),否則不進(jìn)行圖像增強(qiáng)。
還進(jìn)行了數(shù)據(jù)打亂的操作
class Dataset():
#here we will get the images converted to vector form and the corresponding captions
def __init__(self,df,input_size,tokenizer = tokenizer, augmentation = True,max_pad = max_pad):
"""
df = dataframe containing image_1,image_2 and findings
"""
self.image1 = df.image1
self.image2 = df.image2
self.caption = df.decoder_ip #inp
self.caption1 = df.decoder_op #output
self.input_size = input_size #tuple ex: (512,512)
self.tokenizer = tokenizer
self.augmentation = augmentation
self.max_pad = max_pad
#image augmentation
#https://imgaug.readthedocs.io/en/latest/source/overview/flip.html?highlight=Fliplr
self.aug1 = iaa.Fliplr(1) #flip images horizaontally
self.aug2 = iaa.Flipud(1) #flip images vertically
# https://imgaug.readthedocs.io/en/latest/source/overview/convolutional.html?highlight=emboss#emboss
# self.aug3 = iaa.Emboss(alpha=(1), strength=1) #embosses image
# #https://imgaug.readthedocs.io/en/latest/source/api_augmenters_convolutional.html?highlight=sharpen#imgaug.augmenters.convolutional.Sharpen
# self.aug4 = iaa.Sharpen(alpha=(1.0), lightness=(1.5)) #sharpens the image and apply some lightness/brighteness 1 means fully sharpened etc
def __getitem__(self,i):
#gets the datapoint at i th index, we will extract the feature vectors of images after resizing the image and apply augmentation
image1 = cv2.imread(self.image1[i],cv2.IMREAD_UNCHANGED)/255
image2 = cv2.imread(self.image2[i],cv2.IMREAD_UNCHANGED)/255 #here there are 3 channels
image1 = cv2.resize(image1,self.input_size,interpolation = cv2.INTER_NEAREST)
image2 = cv2.resize(image2,self.input_size,interpolation = cv2.INTER_NEAREST)
if image1.any()==None:
print("%i , %s image sent null value"%(i,self.image1[i]))
if image2.any()==None:
print("%i , %s image sent null value"%(i,self.image2[i]))
#tokenizing and padding
caption = self.tokenizer.texts_to_sequences(self.caption[i:i+1]) #the input should be an array for tokenizer ie [self.caption[i]]
caption = pad_sequences(caption,maxlen = self.max_pad,padding = 'post') #opshape:(input_length,)
caption = tf.squeeze(caption,axis=0) #opshape = (input_length,) removing unwanted axis if present
caption1 = self.tokenizer.texts_to_sequences(self.caption1[i:i+1]) #the input should be an array for tokenizer ie [self.caption[i]]
caption1 = pad_sequences(caption1,maxlen = self.max_pad,padding = 'post') #opshape: (input_length,)
caption1 = tf.squeeze(caption1,axis=0) #opshape = (input_length,) removing unwanted axis if present
if self.augmentation: #we will not apply augmentation that crops the image
a = np.random.uniform()
if a<0.333:
image1 = self.aug1.augment_image(image1)
image2 = self.aug1.augment_image(image2)
elif a<0.667:
image1 = self.aug2.augment_image(image1)
image2 = self.aug2.augment_image(image2)
else: #applying no augmentation
pass;
return image1,image2,caption,caption1
def __len__(self):
return len(self.image1)
class Dataloader(tf.keras.utils.Sequence): #for batching
def __init__(self, dataset, batch_size=1, shuffle=True):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.indexes = np.arange(len(self.dataset))
def __getitem__(self, i):
# collect batch data
start = i * self.batch_size
stop = (i + 1) * self.batch_size
indexes = [self.indexes[j] for j in range(start,stop)] #getting the shuffled index values
data = [self.dataset[j] for j in indexes] #taken from Data class (calls __getitem__ of Data) here the shape is batch_size*3, (image_1,image_2,caption)
batch = [np.stack(samples, axis=0) for samples in zip(*data)] #here the shape will become batch_size*input_size(of image)*3,batch_size*input_size(of image)*3
#,batch_size*1*max_pad
return tuple([[batch[0],batch[1],batch[2]],batch[3]]) #here [image1,image2, caption(without <END>)],caption(without <CLS>) (op)
def __len__(self): #returns total number of batches in an epoch
return len(self.indexes) // self.batch_size
def on_batch_end(self): #it runs at the end of epoch
if self.shuffle:
np.random.shuffle(self.indexes) #in-place shuffling takes place
編碼器層使用 ChexNet 權(quán)重對輸入 X 射線進(jìn)行編碼,
class Image_encoder(tf.keras.layers.Layer):
"""
This layer will output image backbone features after passing it through chexnet
here chexnet will be not be trainable
"""
def __init__(self,
name = "image_encoder_block"
):
super().__init__()
self.chexnet = create_chexnet()
self.chexnet.trainable = False
self.avgpool = AveragePooling2D()
def call(self,data):
op = self.chexnet(data) #op shape: (None,7,7,1024)
op = self.avgpool(op) #op shape (None,3,3,1024)
op = tf.reshape(op,shape = (-1,op.shape[1]*op.shape[2],op.shape[3])) #op shape: (None,9,1024)
return op
def encoder(image1,image2,dense_dim = dense_dim,dropout_rate = dropout_rate):
"""
Takes image1,image2
gets the final encoded vector of these
"""
#image1
im_encoder = Image_encoder()
bkfeat1 = im_encoder(image1) #shape: (None,9,1024)
bk_dense = Dense(dense_dim,name = 'bkdense',activation = 'relu') #shape: (None,9,512)
bkfeat1 = bk_dense(bkfeat1)
#image2
bkfeat2 = im_encoder(image2) #shape: (None,9,1024)
bkfeat2 = bk_dense(bkfeat2) #shape: (None,9,512)
#combining image1 and image2
concat = Concatenate(axis=1)([bkfeat1,bkfeat2]) #concatenating through the second axis shape: (None,18,1024)
bn = BatchNormalization(name = "encoder_batch_norm")(concat)
dropout = Dropout(dropout_rate,name = "encoder_dropout")(bn)
return dropout
注意力層:
class global_attention(tf.keras.layers.Layer):
"""
calculate global attention
"""
def __init__(self,dense_dim = dense_dim):
super().__init__()
# Intialize variables needed for Concat score function here
self.W1 = Dense(units = dense_dim) #weight matrix of shape enc_units*dense_dim
self.W2 = Dense(units = dense_dim) #weight matrix of shape dec_units*dense_dim
self.V = Dense(units = 1) #weight matrix of shape dense_dim*1
#op (None,98,1)
def call(self,encoder_output,decoder_h): #here the encoded output will be the concatted image bk features shape: (None,98,dense_dim)
decoder_h = tf.expand_dims(decoder_h,axis=1) #shape: (None,1,dense_dim)
tanh_input = self.W1(encoder_output) + self.W2(decoder_h) #ouput_shape: batch_size*98*dense_dim
tanh_output = tf.nn.tanh(tanh_input)
attention_weights = tf.nn.softmax(self.V(tanh_output),axis=1) #shape= batch_size*98*1 getting attention alphas
op = attention_weights*encoder_output#op_shape: batch_size*98*dense_dim multiply all aplhas with corresponding context vector
context_vector = tf.reduce_sum(op,axis=1) #summing all context vector over the time period ie input length, output_shape: batch_size*dense_dim
return context_vector,attention_weights
單步解碼器:
將input_to_decoder傳遞給嵌入層,然后獲得輸出(batch_size,1, embedding_dim)
使用encoder_output和解碼器隱藏狀態(tài),計算上下文向量。
連接上下文向量與步驟A輸出
將Step-C輸出傳遞給LSTM/GRU,并獲得解碼器輸出和狀態(tài)(隱藏和單元狀態(tài))
將解碼器輸出傳遞到致密層(詞匯表大小),并將結(jié)果存儲到輸出中。
返回Step-D的狀態(tài),Step-E的輸出,Step-B的注意權(quán)重
class One_Step_Decoder(tf.keras.layers.Layer):
"""
decodes a single token
"""
def __init__(self,vocab_size = vocab_size, embedding_dim = embedding_dim, max_pad = max_pad, dense_dim = dense_dim ,name = "onestepdecoder"):
# Initialize decoder embedding layer, LSTM and any other objects needed
super().__init__()
self.dense_dim = dense_dim
self.embedding = Embedding(input_dim = vocab_size+1,
output_dim = embedding_dim,
input_length=max_pad,
weights = [embedding_matrix],
mask_zero=True,
name = 'onestepdecoder_embedding'
)
self.LSTM = GRU(units=self.dense_dim,return_sequences=True,return_state=True,name = 'onestepdecoder_LSTM')
self.LSTM1 = GRU(units=self.dense_dim,return_sequences=False,return_state=True,name = 'onestepdecoder_LSTM1')
self.attention = global_attention(dense_dim = dense_dim)
self.concat = Concatenate(axis=-1)
self.dense = Dense(dense_dim,name = 'onestepdecoder_embedding_dense',activation = 'relu')
self.final = Dense(vocab_size+1,activation='softmax')
self.concat = Concatenate(axis=-1)
self.add =Add()
@tf.function
def call(self,input_to_decoder, encoder_output, decoder_h):#,decoder_c):
'''
One step decoder mechanisim step by step:
A. Pass the input_to_decoder to the embedding layer and then get the output(batch_size,1,embedding_dim)
B. Using the encoder_output and decoder hidden state, compute the context vector.
C. Concat the context vector with the step A output
D. Pass the Step-C output to LSTM/GRU and get the decoder output and states(hidden and cell state)
E. Pass the decoder output to dense layer(vocab size) and store the result into output.
F. Return the states from step D, output from Step E, attention weights from Step -B
here state_h,state_c are decoder states
'''
embedding_op = self.embedding(input_to_decoder) #output shape = batch_size*1*embedding_shape (only 1 token)
context_vector,attention_weights = self.attention(encoder_output,decoder_h) #passing hidden state h of decoder and encoder output
#context_vector shape: batch_size*dense_dim we need to add time dimension
context_vector_time_axis = tf.expand_dims(context_vector,axis=1)
#now we will combine attention output context vector with next word input to the lstm here we will be teacher forcing
concat_input = self.concat([context_vector_time_axis,embedding_op])#output dimension = batch_size*input_length(here it is 1)*(dense_dim+embedding_dim)
output,decoder_h = self.LSTM(concat_input,initial_state = decoder_h)
output,decoder_h = self.LSTM1(output,initial_state = decoder_h)
#output shape = batch*1*dense_dim and decoder_h,decoder_c has shape = batch*dense_dim
#we need to remove the time axis from this decoder_output
output = self.final(output)#shape = batch_size*decoder vocab size
return output,decoder_h,attention_weights
解碼器層負(fù)責(zé)解碼編碼器輸出和標(biāo)題。解碼器迭代所有的時間步,直到最大填充值,并一個一個地生成每個單詞。
class decoder(tf.keras.Model):
"""
Decodes the encoder output and caption
"""
def __init__(self,max_pad = max_pad, embedding_dim = embedding_dim,dense_dim = dense_dim,score_fun='general',batch_size = batch_size,vocab_size = vocab_size):
super().__init__()
self.onestepdecoder = One_Step_Decoder(vocab_size = vocab_size, embedding_dim = embedding_dim, max_pad = max_pad, dense_dim = dense_dim)
self.output_array = tf.TensorArray(tf.float32,size=max_pad)
self.max_pad = max_pad
self.batch_size = batch_size
self.dense_dim =dense_dim
@tf.function
def call(self,encoder_output,caption):#,decoder_h,decoder_c): #caption : (None,max_pad), encoder_output: (None,dense_dim)
decoder_h, decoder_c = tf.zeros_like(encoder_output[:,0]), tf.zeros_like(encoder_output[:,0]) #decoder_h, decoder_c
output_array = tf.TensorArray(tf.float32,size=max_pad)
for timestep in range(self.max_pad): #iterating through all timesteps ie through max_pad
output,decoder_h,attention_weights = self.onestepdecoder(caption[:,timestep:timestep+1], encoder_output, decoder_h)
output_array = output_array.write(timestep,output) #timestep*batch_size*vocab_size
self.output_array = tf.transpose(output_array.stack(),[1,0,2]) #.stack :Return the values in the TensorArray as a stacked Tensor.)
#shape output_array: (batch_size,max_pad,vocab_size)
return self.output_array
這里我們還將一些比較常見的訓(xùn)練技巧加入到了訓(xùn)練中,例如早停機(jī)制,學(xué)習(xí)率計劃和使用tensorboard展示
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
def custom_loss(y_true, y_pred):
#getting mask value to not consider those words which are not present in the true caption
mask = tf.math.logical_not(tf.math.equal(y_true, 0))
#y_pred = y_pred+10**-7 #to prevent loss becoming null
#calculating the loss
loss_ = loss_func(y_true, y_pred)
#converting mask dtype to loss_ dtype
mask = tf.cast(mask, dtype=loss_.dtype)
#applying the mask to loss
loss_ = loss_*mask
#returning mean over all the values
return tf.reduce_mean(loss_)
tf.keras.backend.clear_session()
tb_filename = 'Encoder_Decoder_global_attention/'
tb_file = os.path.join('/content/drive/MyDrive/Project_on_Drive',tb_filename)
model_filename = 'Encoder_Decoder_global_attention.h5'
model_save = os.path.join('/content/drive/MyDrive/Project_on_Drive',model_filename)
my_callbacks = [
tf.keras.callbacks.EarlyStopping(patience = 5,
verbose = 2
),
tf.keras.callbacks.ModelCheckpoint(filepath=model_save,
save_best_only = True,
save_weights_only = True,
verbose = 2
),
tf.keras.callbacks.TensorBoard(histogram_freq=1,
log_dir=tb_file),
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
patience=2, min_lr=10**-7, verbose = 2)
] #from keras documentation
我們的模型結(jié)構(gòu)如下:


訓(xùn)練的參數(shù)如下:
batch_size = 100
embedding_dim = 300
dense_dim = 512
lstm_units = dense_dim
dropout_rate = 0.2
lr (Learning Rate) = 10**-2
number of epochs = 10
min_lr (Minimum Learning rate) =10**-7
模型訓(xùn)練了10輪,可以看到損失為0.5577,精度為0.8466,驗(yàn)證損失和精度分別為1.4386和0.6907,如果我們繼續(xù)運(yùn)行模型,可以得到更好的損失和精度,但看起來模型是過擬合的,因?yàn)榈玫搅?0輪的最佳結(jié)果。



可視化可以看到,評估精度隨著迭代次數(shù)的增加而增加,評估損失隨著迭代次數(shù)的增加而減少,這是一個很好的跡象,表明權(quán)重正在收斂,所有導(dǎo)數(shù)都在良好的范圍內(nèi),沒有爆炸或消失的梯度。
使用Greedy Search測試標(biāo)題預(yù)測和BLEU評分
我們決定使用Greedy Search:是因?yàn)檎陬A(yù)測文本,并且希望在每個單詞之后預(yù)測下一個最佳單詞的概率,并且Greedy Search的計算成本并不高,因?yàn)槲覀儑L試了一些啟發(fā)式搜索算法,例如beam search,結(jié)果證明它們的計算成本很高。
def greedy_search_predict(image1,image2,model = model1):
"""
Given paths to two x-ray images predicts the findings part of the x-ray in a greedy search algorithm
"""
image1 = cv2.imread(image1,cv2.IMREAD_UNCHANGED)/255
image2 = cv2.imread(image2,cv2.IMREAD_UNCHANGED)/255
image1 = tf.expand_dims(cv2.resize(image1,input_size,interpolation = cv2.INTER_NEAREST),axis=0) #introduce batch and resize
image2 = tf.expand_dims(cv2.resize(image2,input_size,interpolation = cv2.INTER_NEAREST),axis=0)
image1 = model.get_layer('image_encoder')(image1)
image2 = model.get_layer('image_encoder')(image2)
image1 = model.get_layer('bkdense')(image1)
image2 = model.get_layer('bkdense')(image2)
concat = model.get_layer('concatenate')([image1,image2])
enc_op = model.get_layer('encoder_batch_norm')(concat)
enc_op = model.get_layer('encoder_dropout')(enc_op) #this is the output from encoder
decoder_h,decoder_c = tf.zeros_like(enc_op[:,0]),tf.zeros_like(enc_op[:,0])
a = []
pred = []
for i in range(max_pad):
if i==0: #if first word
caption = np.array(tokenizer.texts_to_sequences(['<start>'])) #shape: (1,1)
output,decoder_h,attention_weights = model.get_layer('decoder').onestepdecoder(caption,enc_op,decoder_h)#,decoder_c) decoder_c,
#prediction
max_prob = tf.argmax(output,axis=-1) #tf.Tensor of shape = (1,1)
caption = np.array([max_prob]) #will be sent to onstepdecoder for next iteration
if max_prob==np.squeeze(tokenizer.texts_to_sequences(['<end>'])):
break;
else:
a.append(tf.squeeze(max_prob).numpy())
return tokenizer.sequences_to_texts([a])[0]

為什么只有28.3%的BLEU得分。這時深度學(xué)習(xí)需要大量的數(shù)據(jù),但我們提供給模型的數(shù)據(jù)非常少,即使在大量重采樣之后,也會偏向于非疾病數(shù)據(jù),因此這個BLEU評分對于我們使用的數(shù)據(jù)來說已經(jīng)很好了,如果我們有大量的數(shù)據(jù),那么相同的模型將表現(xiàn)得非常好,并給出更好的結(jié)果。



預(yù)測是有意義的,模型能夠預(yù)測疾病和非疾病數(shù)據(jù)。為了提高模型的性能,我們需要更多的數(shù)據(jù),以便我們的模型訓(xùn)練得更好,給出更好的輸出。
總結(jié)
我們能夠成功地為x射線圖像生成標(biāo)題(發(fā)現(xiàn)),并能夠通過帶有GRUs的基于全局注意力的編碼器-解碼器模型實(shí)現(xiàn)約28.3%的BLEU評分。由于我們擁有的數(shù)據(jù)非常少,而且偏向于非患病數(shù)據(jù),我們無法獲得非常好的BLEU得分,但如果我們有大量平衡的數(shù)據(jù),那么同一段代碼可以非常好地預(yù)測圖像的標(biāo)題。
改進(jìn):
可以使用BERT來獲得標(biāo)題嵌入,也可以使用BERT或者在解碼器中使用GPT-2或GPT-3來生成標(biāo)題,可以使用Transformer來代替基于注意力的編碼器-解碼器架構(gòu),獲取更多有疾病的x光圖像,因?yàn)樵摂?shù)據(jù)集中可獲得的大多數(shù)數(shù)據(jù)屬于“無疾病”類別。
本文的代碼如下:
https://github.com/skurnapally/Medical_Image_Captioning_on_Chest_X-Rays

分享
收藏
點(diǎn)贊
在看

