Skip to content
Snippets Groups Projects
utils.py 1.61 KiB
import datetime

import pandas as pd
import torch
from tqdm import tqdm


def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    """Stack variable-length tensors into one tensor padded along dim 0.

    Args:
        sequences: Non-empty sequence of tensors. All tensors must share the
            same trailing dimensions (everything after dim 0). The output takes
            its dtype and device from ``sequences[0]``.
        batch_first: If True the result has shape
            ``(len(sequences), max_len, *trailing)``; otherwise
            ``(max_len, len(sequences), *trailing)``.
        padding_value: Value written to every position not covered by an input.
        max_len: Size of the padded dimension. ``None`` pads to the length of
            the longest input instead of a fixed size.

    Returns:
        A newly allocated tensor holding the inputs, padded with
        ``padding_value``.

    Raises:
        ValueError: If ``sequences`` is empty, or if any tensor is longer than
            ``max_len`` (it would not fit in the output).
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequence() requires at least one tensor")

    if max_len is None:
        # Dynamic padding: the longest input decides the padded length.
        max_len = max(tensor.size(0) for tensor in sequences)

    first = sequences[0]
    trailing_dims = tuple(first.size()[1:])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    # new_full keeps first's dtype/device and does not track gradients,
    # matching the previous `.data.new(...).fill_(...)` construction.
    out_tensor = first.new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        if length > max_len:
            # Without this check the slice assignment below fails with an
            # opaque shape-mismatch error.
            raise ValueError(
                f"sequence {i} has length {length}, which exceeds max_len={max_len}"
            )
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor


def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500, converters=None):
    """Load a CSV file into a DataFrame in chunks while showing a progress bar.

    Args:
        csv_path: Path to a UTF-8 CSV file whose first line is a header.
        nrows: Maximum number of data rows to read (default: all of them).
        chunksize: Number of rows read per chunk; each chunk advances the bar.
        converters: Per-column converter mapping forwarded to ``pd.read_csv``.
            Defaults to ``{'Y': pd.eval, 'Z': pd.eval}``; pass ``{}`` to
            disable conversion.

    Returns:
        A DataFrame with the rows read. A header-only file (or ``nrows=0``)
        yields an empty DataFrame carrying the header's columns.
    """
    print("\n" + "#" * 20)
    print("Loading csv...")

    if converters is None:
        # NOTE(review): pd.eval evaluates cell text as an expression; only use
        # this default on trusted files (ast.literal_eval is safer for literals).
        converters = {'Y': pd.eval, 'Z': pd.eval}

    # Count physical lines for the progress-bar total. Quoted fields containing
    # newlines make this an over-estimate, which is harmless as an nrows cap.
    with open(csv_path, 'r', encoding="utf8") as fh:
        rows = max(sum(1 for _ in fh) - 1, 0)  # minus the header

    # pandas already caps each chunk at the remaining nrows, so chunksize is
    # left as-is to keep the progress bar granular (and never 0).
    if rows > nrows:
        rows = nrows

    chunk_list = []
    with tqdm(total=rows, desc='Rows read: ') as bar:
        for chunk in pd.read_csv(csv_path, converters=converters, chunksize=chunksize,
                                 nrows=rows, encoding="utf8"):
            chunk_list.append(chunk)
            bar.update(len(chunk))

    if chunk_list:
        df = pd.concat(chunk_list, axis=0)
    else:
        # The chunked reader yields nothing when nrows is 0, and pd.concat([])
        # would raise; read just the header to get correctly-named empty columns.
        df = pd.read_csv(csv_path, converters=converters, nrows=0, encoding="utf8")
    print("#" * 20)

    return df


def format_time(elapsed):
    """Render a duration given in seconds as a ``timedelta``-style string.

    The value is first rounded to whole seconds with Python's built-in
    ``round`` (exact halves go to the even neighbour, so 2.5 -> 2). The text is
    ``str(datetime.timedelta(...))``: ``H:MM:SS`` with an unpadded hour field
    (e.g. ``'0:00:05'``), prefixed by a day count once the duration reaches
    24 hours (e.g. ``'1 day, 1:00:00'``).
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return f"{duration}"