Commit 78235faf
Authored 2 years ago by Alice Pain
Commit message: minor
Parent: f6e16e8f
1 changed file: trytorch/train_model_baseline.py (+193, −200)

trytorch/train_model_baseline.py
import numpy as np
import sys
import os
import argparse
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm
bert = 'bert-base-multilingual-cased'
#bert_embeddings = BertModel.from_pretrained(bert)
tokenizer = BertTokenizer.from_pretrained(bert)
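# The LSTM module below wraps mBERT embeddings, a (bi)directional LSTM and a
# linear + sigmoid head, producing one probability per input token.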
class LSTM(nn.Module):

    def __init__(self, batch_size, input_size, hidden_size, n_layers=1, bidirectional=False):
        super().__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.bert_embeddings = BertModel.from_pretrained(bert)
        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
        d = 2 if bidirectional else 1
        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
        self.act = nn.Sigmoid()

    def forward(self, batch):
        output = self.bert_embeddings(**batch)
        output768 = output.last_hidden_state
        output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
        #output64, self.hidden = self.lstm(output768, self.hidden)
        #print("output64=", output64.shape)
        #print("last_hidden_state", last_hidden_state.shape)
        #print("last_cell_state", last_cell_state.shape)
        output1 = self.hiddenToLabel(output64)
        #print("output1=", output1.shape)
        return self.act(output1[:, :, 0])
        #lstm_out, self.hidden = self.lstm(output, self.hidden)
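# Shape walk-through for forward(), with batch size B and sequence length T:
#   bert_embeddings(**batch).last_hidden_state -> (B, T, 768)
#   lstm(...)                                   -> (B, T, d * hidden_size)
#   hiddenToLabel(...)                          -> (B, T, 1)
#   act(...)[:, :, 0]                           -> (B, T) per-token probabilities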
class SentenceBatch():

    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes=None, deprels=None, dheads=None):
        self.sentence_ids = sentence_ids
        self.tokens = tokens
        self.tok_ids = pad_sequence(tok_ids, batch_first=True)  #bert token ids
        self.tok_types = pad_sequence(tok_types, batch_first=True)
        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
        self.labels = pad_sequence(labels, batch_first=True)
        if uposes is not None:
            self.uposes = pad_sequence(uposes, batch_first=True)
        else:
            self.uposes = None
        if deprels is not None:
            self.deprels = pad_sequence(deprels, batch_first=True)
        else:
            self.deprels = None
        if dheads is not None:
            self.dheads = pad_sequence(dheads, batch_first=True)
        else:
            self.dheads = None
        self.labels = pad_sequence(labels, batch_first=True)

    def getBatchEncoding(self):
        dico = {'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks}
        return BatchEncoding(dico)
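# generate_sentence_list loads the pre-parsed corpus (parsed_data/parsed_<corpus>_<sset>.<fmt>.npz)
# and splits it into (tokens, labels) chunks of at most 510 tokens, so that [CLS] and [SEP]
# still fit within BERT's 512-token limit.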
def generate_sentence_list(corpus, sset, fmt):
    #move that part to parse_corpus.py
    parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
    #print("PATH", parsed_data)
    if not os.path.isfile(parsed_data):
        print("you must parse the corpus before training it")
        sys.exit()
    data = np.load(parsed_data, allow_pickle=True)
    tok_ids, toks, labels = [data[f] for f in data.files]
    sentences = []
    if fmt == 'conllu':
        previous_i = 0
        for index, tok_id in enumerate(tok_ids):
            if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
                if index - previous_i <= 510:
                    sentences.append((toks[previous_i:index], labels[previous_i:index]))
                    previous_i = index
                else:
                    sep = previous_i + 510
                    sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
                    if sep - previous_i > 510:
                        print("still too long sentence...")
                        sys.exit()
                    sentences.append((toks[sep:index], labels[sep:index]))
    else:
        max_length = 510
        nb_toks = len(tok_ids)
        for i in range(0, nb_toks - max_length, max_length):
            #add slices of 510 tokens
            sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
        sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
    indexed_sentences = [0] * len(sentences)
    for i, sentence in enumerate(sentences):
        indexed_sentences[i] = (i, sentence)
    return indexed_sentences
def add_cls_sep(sentence):
    return ['[CLS]'] + list(sentence) + ['[SEP]']
def toks_to_ids(sentence):
    #print("sentence=", sentence)
    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
    #print("tokens=", tokens)
    return torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
    #len(tokens)
    #print("token_ids=", token_ids)
def make_labels(sentence):
    zero = np.array([0])
    add_two = np.concatenate((np.concatenate((zero, sentence)), zero))  #add label 0 for [CLS] and [SEP]
    return torch.from_numpy(add_two).float()
def make_tok_types(l):
    return torch.zeros(l, dtype=torch.int32)
def make_tok_masks(l):
    return torch.ones(l, dtype=torch.int32)
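# collate_batch turns a list of (sentence_id, (tokens, labels)) items into a single padded SentenceBatch.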
def collate_batch(batch):
    sentence_ids, token_batch, label_batch = [0] * len(batch), [0] * len(batch), [0] * len(batch)
    for i, (ids, (toks, labs)) in enumerate(batch):
        sentence_ids[i] = ids
        token_batch[i] = toks
        label_batch[i] = labs
    labels = [make_labels(sentence) for sentence in label_batch]
    tokens = [add_cls_sep(sentence) for sentence in token_batch]
    tok_ids = [torch.tensor(tokenizer.convert_tokens_to_ids(sentence)) for sentence in tokens]
    lengths = [len(toks) for toks in tok_ids]
    tok_types = [make_tok_types(l) for l in lengths]
    tok_masks = [make_tok_masks(l) for l in lengths]
    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
def train(corpus, fmt):
    print(f'starting training of {corpus} in format {fmt}...')
    data = generate_sentence_list(corpus, 'train', fmt)
    torch.manual_seed(1)
    #PARAMETERS
    batch_size = 32 if fmt == 'conllu' else 4
    n_epochs = 10
    lr = 0.0001
    reg = 0.001
    dropout = 0.1
    n_layers = 1
    n_hidden = 64
    bidirectional = True
    params = {'fm': fmt, 'nl': n_layers, 'nh': n_hidden, 'bs': batch_size, 'ne': n_epochs,
              'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional}
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    model = LSTM(batch_size, 768, n_hidden, n_layers, bidirectional=bidirectional)
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
    model.train()
    l = len(dataloader.dataset)
    for epoch in range(n_epochs):
        total_acc = 0
        total_loss = 0
        tp = 0
        fp = 0
        fn = 0
        for sentence_batch in tqdm(dataloader):
            optimizer.zero_grad()
            label_batch = sentence_batch.labels
            pred = model(sentence_batch.getBatchEncoding())
            loss = loss_fn(pred, label_batch)
            loss.backward()
            optimizer.step()
            pred_binary = (pred >= 0.5).float()
            for i in range(label_batch.size(0)):
                for j in range(label_batch.size(1)):
                    if pred_binary[i, j] == 1.:
                        if label_batch[i, j] == 1.:
                            tp += 1
                        else:
                            fp += 1
                    elif label_batch[i, j] == 1.:
                        fn += 1
            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
            assert (sum_score <= batch_size)
            total_acc += sum_score
            total_loss += loss.item()  #*label_batch.size(0)
        f1 = tp / (tp + (fp + fn) / 2)
        print(f"Epoch {epoch} Accuracy {total_acc / l} Loss {total_loss / l} F1 {f1}")
    print('done training')
    output_file = save_model(model, 'baseline', corpus, params)
    print(f'model saved at {output_file}')
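# save_model stores the whole model under saved_models/<corpus>/, encoding the
# hyperparameters from params in the file name.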
def save_model(model, model_type, corpus, params):
    models_dir = 'saved_models'
    if not os.path.isdir(models_dir):
        os.system(f'mkdir {models_dir}')
    corpus_dir = os.path.join(models_dir, corpus)
    if not os.path.isdir(corpus_dir):
        os.system(f'mkdir {corpus_dir}')
    model_file = f"{model_type}_{corpus}_{params['nl']}_{params['nh']}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
    output_file = os.path.join(corpus_dir, model_file)
    if not os.path.isfile(output_file):
        torch.save(model, output_file)
    else:
        print("Model already exists. Do you wish to overwrite? (Y/n)")
        user = input()
        if user == "" or user == "Y" or user == "y":
            torch.save(model, output_file)
    return output_file
def main():
    parser = argparse.ArgumentParser(description='Train baseline model')
    parser.add_argument('--corpus', required=True, help='corpus to train')
    parser.add_argument('--format', default='conllu', help='tok or conllu')
    params = parser.parse_args()
    train(params.corpus, params.format)


if __name__ == '__main__':
    main()
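A minimal invocation sketch, inferred from the argparse options above; <corpus> is a placeholder, and the matching parsed_data/parsed_<corpus>_train.<fmt>.npz archive must already exist:

python trytorch/train_model_baseline.py --corpus <corpus> --format conllu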