From 84211ed0a45f92c25343a72452279138069d607b Mon Sep 17 00:00:00 2001 From: Alice Pain <alice.pain@ens.psl.eu> Date: Wed, 13 Jul 2022 16:51:26 +0200 Subject: [PATCH] update --- align/README.md | 5 ++- align/bil2mono.py | 4 +- align/disrpt2readable.py | 1 + align/train_model_baseline.py | 7 ++-- align/train_model_rich.py | 13 +++--- code/ssplit/README.md | 4 +- code/ssplit/tok2conllu.py | 74 ++++++++++++++++++----------------- 7 files changed, 60 insertions(+), 48 deletions(-) diff --git a/align/README.md b/align/README.md index ce6fecf..554a924 100644 --- a/align/README.md +++ b/align/README.md @@ -22,7 +22,8 @@ By default: - for ELMo models: Glove embeddings under **discut/embeddings**: `wget -P embeddings/ "http://nlp.stanford.edu/data/glove.6B.zip"` -`gunzip embeddings/glove.6B.zip` +`cd embeddings` +`unzip glove.6B.zip` # Compute alignment matrix @@ -85,6 +86,6 @@ If an alignment matrix is given (only for baseline model, for now), the rotation - `visualize.py`: functions to visualize results using Pyplot and PCA. -- `parse_corpus.py <corpus> [rich]`: parse a corpus and save useful information in `.npz` file (token ids, tokens, labels, + if rich: uposes, head ids, dependency relations). +- `parse_corpus.py <corpus> [rich]`: parse a corpus and save useful information in `.npz` file (token ids, tokens, labels, + if rich: uposes, head ids, dependency relations). `<corpus>` must have the form `corpus_set.fmt` (where set is test/train/dev and fmt is tok/conllu). - `disrpt2readable.py <file>`: convert a DISRPT RST or SDRT file to an easily readable text file, where segments are indicated by '|'. Text file is saved under the same directory with suffix `.readable`. diff --git a/align/bil2mono.py b/align/bil2mono.py index 4767e7d..7f9f921 100644 --- a/align/bil2mono.py +++ b/align/bil2mono.py @@ -27,8 +27,10 @@ def to_mono(dico): with open(dico, "r") as rf: for line in rf: - split = line.strip().split(" ") + split = line.strip().split() if len(split) != 2: + print(line) + print(split) print(f'Format error in {dico} file') sys.exit() words_src.add(split[0]) diff --git a/align/disrpt2readable.py b/align/disrpt2readable.py index 1f68c59..4d29316 100644 --- a/align/disrpt2readable.py +++ b/align/disrpt2readable.py @@ -13,6 +13,7 @@ def readable(file_name): for tok, lab in zip(tokens, labels): if lab: f.write('| ') f.write(f'{tok} ') + print(f'Wrote file {output_file}.') def main(): if len(sys.argv) != 2: diff --git a/align/train_model_baseline.py b/align/train_model_baseline.py index 2c7629c..2c7e8fd 100644 --- a/align/train_model_baseline.py +++ b/align/train_model_baseline.py @@ -70,12 +70,12 @@ class SentenceBatch(): return BatchEncoding(dico) def generate_sentence_list(corpus, sset, fmt): - #move that part to parse_corpus.py parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz") + filename = f"{corpus}_{sset}.{fmt}" print(f"Looking for parsed file at {parsed_data}.") if not os.path.isfile(parsed_data): print("Parsed data not found.") - tok_ids, toks, labels = parse(corpus, False) + tok_ids, toks, labels = parse(filename, False) else: data = np.load(parsed_data, allow_pickle = True) print("Loaded parsed data.") @@ -182,6 +182,7 @@ def train(corpus, fmt): fp = 0 fn = 0 + print(f"Epoch {epoch}") for sentence_batch in tqdm(dataloader): optimizer.zero_grad() label_batch = sentence_batch.labels @@ -207,7 +208,7 @@ def train(corpus, fmt): precision = tp / (tp + fp) if (tp + fp != 0) else 'n/a' recall = tp / (tp + fn) if (tp + fn != 0) else 'n/a' f1 = 2 * (precision * recall) / 
(precision + recall) if (precision != 'n/a' and recall != 'n/a') else 'n/a' - print(f"Epoch {epoch}\nAcc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n") + print(f"Acc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n") print('Done training') output_file = save_model(model, 'baseline', corpus, params) diff --git a/align/train_model_rich.py b/align/train_model_rich.py index 0065447..32d4071 100644 --- a/align/train_model_rich.py +++ b/align/train_model_rich.py @@ -8,7 +8,9 @@ from torch import nn from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from tqdm import tqdm + from train_model_baseline import SentenceBatch, add_cls_sep, toks_to_ids, make_labels, make_tok_types, make_tok_masks, save_model +from parse_corpus import parse bert = 'bert-base-multilingual-cased' #bert_embeddings = BertModel.from_pretrained(bert) @@ -107,12 +109,13 @@ def dh_list(tok_ids, idheads, s, e): def generate_rich_sentence_list(corpus, sset, fmt): parsed_data = os.path.join("parsed_data", f"parsedrich_{corpus}_{sset}.{fmt}.npz") - #print("PATH", parsed_data) + filename = f"{corpus}_{sset}.{fmt}" if not os.path.isfile(parsed_data): - print("you must parse the corpus before training it") - sys.exit() - data = np.load(parsed_data, allow_pickle = True) - tok_ids, toks, labels, upos, idheads, deprels = [data[f] for f in data.files] + print("Parsed data not found.") + tok_ids, toks, labels, upos, idheads, deprels = parse(filename, full_name=False, rich=True) + else: + data = np.load(parsed_data, allow_pickle = True) + tok_ids, toks, labels, upos, idheads, deprels = [data[f] for f in data.files] unique_upos = np.unique(upos) unique_deprels = np.unique(deprels) diff --git a/code/ssplit/README.md b/code/ssplit/README.md index a941b17..5c8dc0c 100644 --- a/code/ssplit/README.md +++ b/code/ssplit/README.md @@ -1,6 +1,6 @@ # Requirements -`ersatz` (`pip install ersatz`) +`ersatz 0.0.1` : `pip install ersatz==0.0.1` # Usage @@ -16,7 +16,7 @@ File must be a `.tok` file. Command-line usage: -`ersatz --input <input.txt> --output <output.txt>` +`ersatz <input.txt> > <output.txt>` Takes as input any text file and outputs the same text file with sentences separated by a line-break. diff --git a/code/ssplit/tok2conllu.py b/code/ssplit/tok2conllu.py index a223a0e..001b49b 100644 --- a/code/ssplit/tok2conllu.py +++ b/code/ssplit/tok2conllu.py @@ -2,15 +2,16 @@ import sys import pandas as pd import os +"""This file doesn't quite work for now but it might just be an index error""" + tab = "\t" -space = " " def parse_file(f): """Take a .tok file and turn it into a sequence of token ids and tokens (.tok_seq). 
Token id precedes token.""" column_names = ['tok_id', 'tok', '1', '2', '3', '4', '5', '6', '7', 'seg'] - dataframe = pd.read_csv(f, names=column_names, comment="#", sep="\t",skipinitialspace=True) + dataframe = pd.read_csv(f, names=column_names, comment="#", sep="\t",skipinitialspace=True, quoting=3) tok_ids = dataframe['tok_id'].values toks = dataframe['tok'].values @@ -26,54 +27,57 @@ def write_seq_file(f, tok_ids, toks): def parse_ssplit_file(f): """Take a .tok_ssplit file and return ids of sentence-beginning tokens.""" + sstart_ids = [] with open(f, "r") as rf: - sentences = rf.readlines() - - sstart_ids = [0] * len(sentences) - for i, sentence in enumerate(sentences): - ids_toks = sentence.strip().split(space) - sstart_ids[i] = ids_toks[0] + for sentence in rf: + ids_toks = sentence.strip().split() + if len(ids_toks) > 0 and ids_toks[0].isnumeric() and int(ids_toks[0]) > 0: + sstart_ids.append(ids_toks[0]) return sstart_ids def make_ssplit(rf, wf, sstart_ids): """Write new file with sentence boundaries""" - with open(rf, "r") as f: - lines = f.readlines() - - doc_id = None - next_sentence = 0 #index of token beginning next sentence in sstart_ids - sent_counter = 0 - - with open(wf, "w") as f: - for line in lines: - split = line.strip().split(tab) - tok_id = split[0] - if tok_id.startswith("#"): - doc_id = line - sent_counter = 0 - f.write(line) - elif tok_id == sstart_ids[next_sentence]: - doc_id_nb = doc_id.strip().split("= ")[1] - if sent_counter: newline = "\n" - else: newline = "" - sent_counter += 1 - sent_id = "# sent_id = " + doc_id_nb + "-" + str(sent_counter) - f.write(newline + sent_id + "\n") - f.write(line) - next_sentence += 1 - else: - f.write(line) + with open(rf, "r") as readf: + #lines = f.readlines() + + #doc_id = None + next_sentence = 0 #index of token beginning next sentence in sstart_ids + #sent_counter = 0 + + with open(wf, "w") as writef: + for read_line in readf: + + if read_line != '': + split = read_line.strip().split(tab) + tok_id = split[0] + if tok_id.startswith("#"): + #doc_id = read_line + #sent_counter = 0 + writef.write(read_line) + elif tok_id == sstart_ids[next_sentence]: + #doc_id_nb = doc_id.strip().split("= ")[1] + #if sent_counter: newline = "\n" + #else: newline = "" + #sent_counter += 1 + #sent_id = "# sent_id = " + doc_id_nb + "-" + str(sent_counter) + #writef.write(newline + sent_id + "\n") + writef.write('\n') + writef.write(read_line) + next_sentence += 1 + else: + writef.write(read_line) def t2c(f): dataframe, tok_ids, toks = parse_file(f) f_seq = f + "_seq" write_seq_file(f_seq, tok_ids, toks) f_ssplit = f + "_ssplit" - os.system(f"ersatz --input {f_seq} --output {f_ssplit}") + os.system(f"ersatz {f_seq} > {f_ssplit}") sstart_ids = parse_ssplit_file(f_ssplit) f_conllu = f + "_conllu" make_ssplit(f, f_conllu, sstart_ids) + print(f'Sentence-split tok file written at {f_conllu}.') os.system(f"rm {f_seq} {f_ssplit}") #remove temporary files def main(): -- GitLab
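A condensed sketch of the cache-or-parse fallback this patch adds to both training scripts (`generate_sentence_list` and `generate_rich_sentence_list`): when the pre-parsed `.npz` file is missing, the corpus is now parsed on the fly instead of exiting. The `load_or_parse` helper is illustrative; the real scripts inline this logic and call `parse_corpus.parse` with the arguments shown in the diff.

```python
import os
import numpy as np
from parse_corpus import parse  # import added to the rich training script by this patch

def load_or_parse(corpus, sset, fmt, rich=False):
    """Illustrative helper: return the parsed arrays, using the cached .npz when present."""
    prefix = "parsedrich" if rich else "parsed"
    parsed_data = os.path.join("parsed_data", f"{prefix}_{corpus}_{sset}.{fmt}.npz")
    filename = f"{corpus}_{sset}.{fmt}"
    print(f"Looking for parsed file at {parsed_data}.")
    if not os.path.isfile(parsed_data):
        # New behaviour: parse on the fly instead of asking the user to parse first.
        print("Parsed data not found.")
        if rich:
            return parse(filename, full_name=False, rich=True)
        return parse(filename, False)
    data = np.load(parsed_data, allow_pickle=True)
    print("Loaded parsed data.")
    return [data[f] for f in data.files]
```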
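The per-epoch scores printed at the end of the baseline training loop guard empty denominators by falling back to `'n/a'`. A minimal standalone sketch of that computation (the `epoch_metrics` wrapper is illustrative; the extra `precision + recall > 0` check avoids a division by zero that the original guards alone do not cover):

```python
def epoch_metrics(tp, fp, fn):
    """Illustrative wrapper around the precision/recall/F1 logic in train_model_baseline.py."""
    precision = tp / (tp + fp) if (tp + fp) != 0 else 'n/a'
    recall = tp / (tp + fn) if (tp + fn) != 0 else 'n/a'
    f1 = (2 * (precision * recall) / (precision + recall)
          if precision != 'n/a' and recall != 'n/a' and (precision + recall) > 0
          else 'n/a')
    return precision, recall, f1

print(epoch_metrics(tp=40, fp=10, fn=10))  # (0.8, 0.8, 0.8)
print(epoch_metrics(tp=0, fp=0, fn=5))     # ('n/a', 0.0, 'n/a')
```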
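Finally, a hypothetical usage sketch of the updated `code/ssplit/tok2conllu.py` pipeline. It assumes `ersatz 0.0.1` is installed and, as the README change documents, reads a positional input file and writes the split sentences to stdout; the sample file name is illustrative, and the import assumes the module guards its command-line entry point.

```python
# Run from code/ssplit on a DISRPT-style .tok file.
from tok2conllu import t2c

# t2c("sample_test.tok") chains the steps touched by this patch:
#  1. parse_file reads the .tok file into token ids and tokens (quoting=3, so quote
#     characters in the text are kept literal),
#  2. write_seq_file dumps them to a temporary sample_test.tok_seq file,
#  3. ersatz splits that file: ersatz sample_test.tok_seq > sample_test.tok_ssplit,
#  4. parse_ssplit_file collects the sentence-initial token ids,
#  5. make_ssplit rewrites the tokens with a blank line before each sentence-initial
#     token, then the temporary files are removed.
t2c("sample_test.tok")
# -> "Sentence-split tok file written at sample_test.tok_conllu."
```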