# utils.py
# Author: Caroline DE POURTALES
import datetime
import pandas as pd
import torch
from tqdm import tqdm
def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    """Stack variable-length tensors into one tensor of fixed length ``max_len``.

    Each tensor is copied into the output along its first dimension and the
    remaining positions are filled with ``padding_value``. A tensor longer
    than ``max_len`` is truncated to its first ``max_len`` entries, so the
    output shape depends only on ``max_len`` and never on the inputs' lengths.

    Args:
        sequences: Non-empty list of tensors. The trailing dimensions (all
            but the first) are read from ``sequences[0]``; the other tensors
            are expected to share them.
        batch_first: If true the output is laid out as
            ``(len(sequences), max_len, *trailing)``, otherwise as
            ``(max_len, len(sequences), *trailing)``.
        padding_value: Fill value for positions not covered by a sequence.
        max_len: Size of the padded dimension.

    Returns:
        A new tensor created from ``sequences[0]`` (same dtype and device)
        holding the padded, and if necessary truncated, sequences.
    """
    trailing_dims = tuple(sequences[0].size()[1:])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        # Clamp to max_len: without this, a longer sequence makes the slice
        # assignment below fail with a shape mismatch.
        length = min(tensor.size(0), max_len)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor[:length]
        else:
            out_tensor[:length, i, ...] = tensor[:length]
    return out_tensor
def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
    """Load a CSV file into a DataFrame while displaying a tqdm progress bar.

    The file is read twice: once to count its lines (so the progress bar has
    a total) and once by pandas, chunk by chunk.

    Args:
        csv_path: Path of the CSV file; the first line is treated as a header.
        nrows: Maximum number of data rows to load. When the file has more
            rows than this, only ``nrows`` rows are read, in a single chunk.
        chunksize: Number of rows per chunk, i.e. per progress-bar update.

    Returns:
        The chunks concatenated row-wise into one DataFrame. Columns ``Y`` and
        ``Z`` are converted by passing each cell to ``pd.eval``.
    """
    print("\n" + "#" * 20)
    print("Loading csv...")

    # Count lines inside a context manager so the handle is closed right away
    # instead of being left for the garbage collector.
    with open(csv_path, 'r', encoding="utf8") as csv_file:
        rows = sum(1 for _ in csv_file) - 1  # minus the header

    if rows > nrows:
        rows = nrows
        chunksize = nrows

    chunk_list = []
    with tqdm(total=rows, desc='Rows read: ') as bar:
        # NOTE(review): pd.eval evaluates the cell text of columns 'Y' and 'Z'
        # as expressions. Only use this on CSV files from a trusted source.
        reader = pd.read_csv(
            csv_path,
            converters={'Y': pd.eval, 'Z': pd.eval},
            chunksize=chunksize,
            nrows=rows,
        )
        for chunk in reader:
            chunk_list.append(chunk)
            bar.update(len(chunk))

    df = pd.concat(chunk_list, axis=0)
    print("#" * 20)
    return df
def format_time(elapsed):
    """Render a duration given in seconds as an ``H:MM:SS`` string.

    The value is rounded to a whole number of seconds and formatted with the
    string form of ``datetime.timedelta`` (which adds a day count for
    durations of 24 hours or more).
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)