Commit d581e05a authored by Julien Rabault

Don't use bert loss

parent 5fc882ea
1 merge request: !1 Draft: Master
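The change below stops using the loss returned by the BERT-based tagging model and instead computes nn.CrossEntropyLoss(ignore_index=0) directly on the model's logits. Note that torch.autograd.Variable is a legacy wrapper in current PyTorch, so wrapping the targets in it is equivalent to passing the tensor itself. A minimal sketch of the pattern, assuming the model takes (input_ids, attention_mask, targets) and returns a (loss, logits) pair with logits of shape (batch, seq_len, num_tags); the train_step helper name is made up for illustration:

import torch
from torch import nn

criterion = nn.CrossEntropyLoss(ignore_index=0)  # tag id 0 (padding) does not contribute to the loss

def train_step(model, input_ids, attention_mask, targets):
    # Discard the loss computed inside the model and keep only the logits.
    _, logits = model((input_ids, attention_mask, targets))  # logits: (batch, seq_len, num_tags)
    # CrossEntropyLoss expects class scores on dim 1, hence the transpose to (batch, num_tags, seq_len).
    loss = criterion(logits.transpose(1, 2), targets)
    predictions = torch.argmax(logits, dim=2)
    return loss, predictions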
@@ -131,10 +131,10 @@ X, Y1, Y2, Z, vocabulary, vnorm, partsofspeech1, partsofspeech2, superset, maxle
df = pd.DataFrame(columns = ["X", "Y1", "Y2", "Z"])
df['X'] = X
df['Y1'] = Y1
df['Y2'] = Y2
df['Z'] = Z
df['X'] = X[:len(X)-1]
df['Y1'] = Y1[:len(X)-1]
df['Y2'] = Y2[:len(X)-1]
df['Z'] = Z[:len(X)-1]
df.to_csv("../m2_dataset_V2.csv", index=False)
@@ -15768,4 +15768,3 @@ X,Y1,Y2,Z
L' effet indésirable le plus fréquent avec Angiox ( observé chez plus d' un patient sur 10 ) est le saignement bénin .,"['DET', 'NC', 'ADJ', 'DET', 'ADV', 'ADJ', 'P', 'NPP', 'PONCT', 'VPP', 'P', 'ADV', 'P', 'DET', 'NC', 'P', 'PRO', 'PONCT', 'V', 'DET', 'NC', 'ADJ', 'PONCT']","['DET:ART', 'NOM', 'ADJ', 'DET:ART', 'ADV', 'ADJ', 'PRP', 'NAM', 'PUN', 'VER:pper', 'PRP', 'ADV', 'PRP', 'DET:ART', 'NOM', 'PRP', 'NUM', 'PUN', 'VER:pres', 'DET:ART', 'NOM', 'ADJ', 'PUN']","['dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,n,n),dl(0,n,n))', 'dr(0,dl(0,n,n),dl(0,n,n))', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,np),dl(0,n,n))', 'dl(0,n,n)', 'dr(0,dl(1,dl(0,n,n),dl(0,n,n)),np)', 'dr(0,np,pp_de)', 'dr(0,pp_de,np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'np', 'let', 'dr(0,dl(0,np,s),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dl(0,s,txt)']"
" Pour avoir le détail de tous les effets indésirables observés lors de l' utilisation de Angiox , voir la notice .","['P', 'VINF', 'DET', 'NC', 'P', 'ADV', 'DET', 'NC', 'ADJ', 'VPP', 'ADV', 'P', 'DET', 'NC', 'P', 'NPP', 'PONCT', 'VINF', 'DET', 'NC', 'PONCT']","['PRP', 'VER:infi', 'DET:ART', 'NOM', 'PRP', 'ADV', 'DET:ART', 'NOM', 'ADJ', 'VER:pper', 'ADV', 'PRP', 'DET:ART', 'NOM', 'PRP', 'NAM', 'PUN', 'VER:infi', 'DET:ART', 'NOM', 'PUN']","['dr(0,dr(0,dl(0,np,s),dl(0,np,s)),dl(0,np,s_inf))', 'dr(0,dl(0,np,s_inf),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),np)', 'dr(0,np,np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dl(0,n,n)', 'dr(0,dl(1,dl(0,n,n),dl(0,n,n)),pp_de)', 'dr(0,pp_de,np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),np)', 'np', 'let', 'dr(0,dl(0,np,s_inf),np)', 'dr(0,np,n)', 'n', 'dl(0,dl(0,np,s),txt)']"
" Angiox ne doit pas être utilisé chez les personnes pouvant présenter une hypersensibilité ( allergie ) à la bivalirudine , aux autres hirudines , ou à l' un des autres composants constituant Angiox .","['NPP', 'ADV', 'V', 'ADV', 'VINF', 'VPP', 'P', 'DET', 'NC', 'VPR', 'VINF', 'DET', 'NC', 'PONCT', 'NC', 'PONCT', 'P', 'DET', 'NC', 'PONCT', 'P+D', 'ADJ', 'NC', 'PONCT', 'CC', 'P', 'DET', 'NC', 'P+D', 'ADJ', 'NC', 'VPR', 'NPP', 'PONCT']","['NAM', 'ADV', 'VER:pres', 'ADV', 'VER:infi', 'VER:pper', 'PRP', 'DET:ART', 'NOM', 'VER:ppre', 'VER:infi', 'DET:ART', 'NOM', 'PUN', 'NOM', 'PUN', 'PRP', 'DET:ART', 'NOM', 'PUN', 'PRP:det', 'ADJ', 'NOM', 'PUN', 'KON', 'PRP', 'DET:ART', 'NUM', 'PRP:det', 'ADJ', 'NOM', 'VER:ppre', 'NAM', 'PUN']","['np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),dl(0,np,s_inf))', 'dl(1,s,s)', 'dr(0,dl(0,np,s_inf),dl(0,np,s_pass))', 'dl(0,np,s_pass)', 'dr(0,dl(1,s,s),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),dl(0,np,s_inf))', 'dr(0,dl(0,np,s_inf),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),n)', 'n', 'let', 'dr(0,dl(0,n,n),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dr(0,dl(0,n,n),n)', 'dr(0,n,n)', 'n', 'let', 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dr(0,dl(0,n,n),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),n)', 'dr(0,n,n)', 'n', 'dr(0,dl(0,n,n),np)', 'np', 'dl(0,s,txt)']"
,[],[],[]
@@ -9,6 +9,7 @@ import datetime
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -55,6 +56,7 @@ class SuperTagger:
self.model = None
self.optimizer = None
self.loss = nn.CrossEntropyLoss(ignore_index=0)
self.epoch_i = 0
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -121,6 +123,8 @@ class SuperTagger:
pred = self.model.predict((sents_tokenized_t, sents_mask_t))
print(pred)
return self.tags_tokenizer.convert_ids_to_tags(pred.detach())
def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=32, tensorboard=False,
@@ -208,10 +212,15 @@ class SuperTagger:
self.optimizer.zero_grad()
loss, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
_, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
label_ids = targets.cpu().numpy()
print()
#torch.nn.functional.one_hot(targets).long()
# torch.argmax(logit)
loss = self.loss(torch.transpose(logit, 1, 2),Variable(targets))
acc = categorical_accuracy(predictions, label_ids)
@@ -243,7 +252,9 @@ class SuperTagger:
b_sents_mask = batch[1].to(self.device)
b_symbols_tokenized = batch[2].to(self.device)
loss, logits = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
_, logits = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
loss = self.loss(torch.transpose(logits, 1, 2), Variable(b_symbols_tokenized))
predictions = torch.argmax(logits, dim=2).detach().cpu().numpy()
label_ids = b_symbols_tokenized.cpu().numpy()
@@ -13,29 +13,29 @@ class SentencesTokenizer():
def fit_transform_tensors(self, sents):
# , return_tensors = 'pt'
temp = self.tokenizer(sents, padding=True, return_offsets_mapping = True)
len_sent_max = len(temp['attention_mask'][0])
input_ids = np.ones((len(sents),len_sent_max))
attention_mask = np.zeros((len(sents),len_sent_max))
for i in range(len(temp['offset_mapping'])):
h = 1
input_ids[i][0] = self.tokenizer.cls_token_id
attention_mask[i][0] = 1
for j in range (1,len_sent_max-1):
if temp['offset_mapping'][i][j][1] != temp['offset_mapping'][i][j+1][0]:
input_ids[i][h] = temp['input_ids'][i][j]
attention_mask[i][h] = 1
h += 1
input_ids[i][h] = self.tokenizer.eos_token_id
attention_mask[i][h] = 1
input_ids = torch.tensor(input_ids).long()
attention_mask = torch.tensor(attention_mask)
return input_ids, attention_mask
temp = self.tokenizer(sents, padding=True, return_offsets_mapping = True, return_tensors = 'pt')
#
# len_sent_max = len(temp['attention_mask'][0])
#
# input_ids = np.ones((len(sents),len_sent_max))
# attention_mask = np.zeros((len(sents),len_sent_max))
#
# for i in range(len(temp['offset_mapping'])):
# h = 1
# input_ids[i][0] = self.tokenizer.cls_token_id
# attention_mask[i][0] = 1
# for j in range (1,len_sent_max-1):
# if temp['offset_mapping'][i][j][1] != temp['offset_mapping'][i][j+1][0]:
# input_ids[i][h] = temp['input_ids'][i][j]
# attention_mask[i][h] = 1
# h += 1
# input_ids[i][h] = self.tokenizer.eos_token_id
# attention_mask[i][h] = 1
#
# input_ids = torch.tensor(input_ids).long()
# attention_mask = torch.tensor(attention_mask)
return temp["input_ids"], temp["attention_mask"]
def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False):
return self.tokenizer.batch_decode(inputs_ids, skip_special_tokens=skip_special_tokens)
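In fit_transform_tensors, the hand-rolled subword realignment based on offset_mapping is commented out and the tokenizer is now called with return_tensors='pt', so padded input_ids and attention_mask come back directly as PyTorch tensors. A minimal sketch of that call, assuming a Hugging Face fast tokenizer such as camembert-base; the example sentences are made up:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("camembert-base")
sents = ["Le chat dort .", "Il pleut ."]

# padding=True pads every sentence to the longest one in the batch;
# return_tensors='pt' returns PyTorch tensors instead of Python lists.
enc = tokenizer(sents, padding=True, return_offsets_mapping=True, return_tensors="pt")
input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]
print(input_ids.shape, attention_mask.shape)  # both (batch_size, max_subword_len)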
@@ -6,11 +6,13 @@ from SuperTagger.Utils.SentencesTokenizer import SentencesTokenizer
from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer
from SuperTagger.Utils.utils import read_csv_pgbar
def categorical_accuracy(preds, truth):
flat_preds = preds[:len(truth)].flatten()
flat_labels = truth.flatten()
return np.sum(flat_preds == flat_labels) / len(flat_labels)
def load_obj(name):
with open(name + '.pkl', 'rb') as f:
import pickle
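The categorical_accuracy helper added above flattens predictions and gold labels and reports the fraction of matching positions; note that padding positions (label 0) count as matches here. A small usage sketch with made-up arrays, repeating the helper so the snippet stands alone:

import numpy as np

def categorical_accuracy(preds, truth):
    # Same logic as the helper above: token-level accuracy over flattened arrays.
    flat_preds = preds[:len(truth)].flatten()
    flat_labels = truth.flatten()
    return np.sum(flat_preds == flat_labels) / len(flat_labels)

preds = np.array([[3, 7, 2, 0], [5, 5, 1, 0]])  # predicted tag ids, 0 = padding
truth = np.array([[3, 7, 4, 0], [5, 5, 1, 0]])  # gold tag ids
print(categorical_accuracy(preds, truth))       # 7 of 8 positions match -> 0.875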
@@ -19,16 +21,13 @@ def load_obj(name):
file_path = 'Datasets/m2_dataset_V2.csv'
df = read_csv_pgbar(file_path, 10)
texts = df['X'].tolist()
tags = df['Z'].tolist()
texts = texts[:1]
tags = tags[:1]
texts = texts[:3]
tags = tags[:3]
tagger = SuperTagger()
@@ -52,12 +51,20 @@ tagger.load_weights("models/model_check.pt")
pred = tagger.predict(texts)
print(tags)
print(tags[1])
print()
print(pred[0])
print(pred[1])
print(pred[0][0] == tags[0])
def categorical_accuracy(preds, truth):
flat_preds = preds.flatten()
flat_labels = truth.flatten()
good_label = 0
for i in range(len(flat_preds)):
if flat_labels[i] == flat_preds[i] and flat_labels[i] != 0:
good_label += 1
print(np.sum(pred[0][:len(tags)] == tags) / len(tags))
return good_label / len(flat_labels)
print(categorical_accuracy(np.array(pred), np.array(tags)))
@@ -11,7 +11,7 @@ def load_obj(name):
file_path = 'Datasets/m2_dataset_V2.csv'
df = read_csv_pgbar(file_path,100)
df = read_csv_pgbar(file_path,50)
texts = df['X'].tolist()
@@ -31,7 +31,7 @@ tagger = SuperTagger()
tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
tagger.train(texts,tags,validation_rate=0,tensorboard=True,checkpoint=True)
tagger.train(texts,tags,validation_rate=0.1,tensorboard=True,checkpoint=True)
pred = tagger.predict(test_s)