"""
Diverse containers (IntPool, RecordList, ClassDict)
"""
from __future__ import annotations
from collections import namedtuple as _namedtuple
import dataclasses
from keyword import iskeyword as _iskeyword
from collections import OrderedDict as _OrderedDict
import itertools
from typing import Sequence
[docs]
class FullError(Exception):
    """Error signalling that a container is full and cannot take more items."""
[docs]
class EmptyError(Exception):
    """Error raised when an item is requested from a container with nothing left (e.g. a fixed-size IntPool)."""
[docs]
class IntPool:
    """
    A pool of integers

    A pool will contain the integers in the range [start, start + capacity)

    Args:
        capacity: the capacity (size) of the pool.
        start: first element
        fixedsize: if True, this pool cannot be extended. Otherwise, when
            a pop operation is attempted on an empty pool, the pool is
            extended by its current capacity (i.e. doubled in size; a pool
            with capacity 0 grows by one), adding the new items. A pool
            cannot be extended by adding elements outside its range, so a
            push operation might still fail if the item is outside the
            range of the pool

    Example
    ~~~~~~~

        >>> from emlib.containers import IntPool
        >>> pool = IntPool(10)
        >>> token = pool.pop()
        >>> len(pool)
        9
        >>> pool.push(token)
        >>> len(pool)
        10
        >>> pool.push(4)
        Traceback (most recent call last):
            ...
        ValueError: token 4 already in pool
    """

    def __init__(self, capacity: int, start=0, fixedsize=True):
        self.capacity = capacity
        # Integers currently available to be taken
        self.pool = set(range(start, start + capacity))
        # Half-open range [start, end) of the integers belonging to this pool
        self.tokenrange = (start, start + capacity)
        self.fixedsize = fixedsize

    def pop(self) -> int:
        """
        Take an item from the pool

        Returns:
            an integer from the pool (which one is unspecified)

        Raises:
            EmptyError: if the pool is empty and has a fixed size
        """
        if not self.pool:
            if self.fixedsize:
                raise EmptyError("This pool is empty")
            # Grow by the current capacity (doubling the pool). A pool with
            # capacity 0 must still grow by at least one item, otherwise
            # set.pop() below would raise KeyError
            self._extend(max(self.capacity, 1))
        return self.pool.pop()

    def push(self, token: int) -> None:
        """
        Return an item to the pool

        Args:
            token: an integer belonging to the range of this pool

        Raises:
            ValueError: if the token is already in the pool or lies outside
                the range of the pool
            FullError: if the pool is already full
        """
        if token in self.pool:
            raise ValueError(f"token {token} already in pool")
        if not self.tokenrange[0] <= token < self.tokenrange[1]:
            raise ValueError("This token is not part of the pool")
        # Explicit check rather than an assert, which is stripped under -O
        if len(self.pool) >= self.capacity:
            raise FullError("This pool is full")
        self.pool.add(token)

    def __contains__(self, item):
        return item in self.pool

    def __len__(self) -> int:
        return len(self.pool)

    def _extend(self, extrasize: int) -> None:
        """Grow the pool by `extrasize` integers appended after the current range"""
        self.capacity += extrasize
        start, end = self.tokenrange
        self.tokenrange = (start, end + extrasize)
        self.pool.update(range(end, end + extrasize))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# RecordList: a list of namedtuples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
class RecordList(list):
    """
    A list of namedtuples / dataclasses

    Args:
        data: A seq of namedtuples or dataclass objects. A seq. of tuples or lists
            is also possible. In that case, fields must be given
        fields: the column names, either as a sequence of strings or as a
            single string as passed to namedtuple (names separated by
            whitespace and/or commas). If given, it takes precedence over the
            field names of the namedtuples / dataclasses in data. Names which
            are not valid identifiers are renamed
        itemname: The name of each row (optional), overrides the name given for the namedtuples

    Example
    ~~~~~~~

    .. code::

        # generate a RecordList of measures
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Measure:
        ...     tempo: int
        ...     timesig: tuple[int, int]
        >>> measures = [Measure(tempo, timesig) for tempo, timesig in [(60, (3, 4)), (60, (4, 4)), (72, (5, 8))]]
        >>> measurelist = RecordList(measures)
        >>> measurelist.get_column('tempo')
        [60, 60, 72]
    """

    def __init__(self,
                 data: list,
                 fields: str | Sequence[str] = '',
                 itemname=''):
        if not isinstance(data, list):
            # Accept any iterable (tuple, generator, ...): the checks below
            # need truth-testing by length and indexing
            data = list(data)
        if not data and not fields:
            raise ValueError("data is empty, fields must be given")
        first = data[0] if data else None
        is_namedtuple = bool(data) and hasattr(first, "_fields")
        is_dataclass = bool(data) and RecordList._is_dataclass_instance(first)

        if itemname:
            self.item_name = itemname
        elif is_namedtuple or is_dataclass:
            self.item_name = type(first).__name__
        else:
            self.item_name = "Row"
        self._name = "{name}s".format(name=self.item_name)

        if fields:
            # Explicit fields always win. A string is split the way
            # namedtuple splits it: on whitespace and/or commas
            if isinstance(fields, str):
                fieldstup = fields.replace(",", " ").split()
            else:
                fieldstup = tuple(fields)
        elif is_namedtuple:
            fieldstup = first._fields
        elif is_dataclass:
            fieldstup = tuple(f.name for f in dataclasses.fields(first))
        else:
            raise ValueError("A seq. of namedtuples must be given or a seq. of tuples. "
                             "For the latter, 'fields' must be specified "
                             f"(got {type(first)})")
        fieldstup = _validate_fields(fieldstup)

        list.__init__(self, data)

        if is_namedtuple:
            make = getattr(first, "_make", None)
            self._item_constructor = (lambda *args: make(args)) if make is not None else None
        elif is_dataclass:
            # A dataclass is constructed positionally, in field order
            self._item_constructor = type(first)
        else:
            # Built lazily as a namedtuple by the item_constructor property
            self._item_constructor = None
        self.columns = fieldstup

    # ######################################################
    # row access helpers
    # ######################################################

    @staticmethod
    def _is_dataclass_instance(obj) -> bool:
        """True if obj is an instance of a dataclass (not a dataclass type)"""
        return dataclasses.is_dataclass(obj) and not isinstance(obj, type)

    @staticmethod
    def _row_values(row) -> tuple:
        """The values of a row (tuple-like or dataclass) as a plain tuple, in column order"""
        if RecordList._is_dataclass_instance(row):
            return tuple(getattr(row, f.name) for f in dataclasses.fields(row))
        return tuple(row)

    @staticmethod
    def _cell(row, index: int):
        """The value at column position `index` of a row (tuple-like or dataclass)"""
        if RecordList._is_dataclass_instance(row):
            # Access by the real attribute name, which may differ from the
            # (validated) column name
            return getattr(row, dataclasses.fields(row)[index].name)
        return row[index]

    def _column_index(self, column: int | str) -> int:
        """Resolve a column given by name or index to its index"""
        if isinstance(column, int):
            return column
        if isinstance(column, str):
            try:
                return self.columns.index(column)
            except ValueError:
                raise ValueError(f"column {column} not found") from None
        raise TypeError("column should be a label (str), or an index (int)")

    def __repr__(self):
        try:
            import tabulate
        except ImportError:
            # tabulate is optional: a repr should never raise
            return f"{type(self).__name__}({list.__repr__(self)})"
        return tabulate.tabulate(
            [self._row_values(row) for row in self], self.columns,
            disable_numparse=True, showindex=True
        )

    def _repr_html_(self):
        return self.to_html()

    def to_html(self, showindex=True) -> str:
        """
        Render this RecordList as an html table (needs tabulate)

        Args:
            showindex: if True, include the row index as first column

        Returns:
            the html code as a string
        """
        import tabulate
        return tabulate.tabulate(
            [self._row_values(row) for row in self],
            self.columns,
            tablefmt="html",
            disable_numparse=True,
            showindex=showindex,
        )

    def __getitem__(self, val) -> tuple | RecordList:
        if isinstance(val, slice):
            # Keep columns and item name so that slices of plain-tuple rows
            # and empty slices are valid RecordLists
            return RecordList(list.__getitem__(self, val),
                              fields=self.columns, itemname=self.item_name)
        return list.__getitem__(self, val)

    def reversed(self) -> RecordList:
        """
        return a reversed copy of self
        """
        return RecordList(list(reversed(self)), fields=self.columns, itemname=self.item_name)

    def copy(self) -> RecordList:
        """
        return a copy of self
        """
        return RecordList(self, fields=self.columns, itemname=self.item_name)

    # ######################################################
    # columns
    # ######################################################

    def get_column(self, column: int | str) -> list:
        """
        Return a column by name or index as a list of values.

        Args:
            column: the column to get, as index or column name

        Returns:
            a list with the values

        Raises:
            ValueError: if a column name is not found
            TypeError: if column is neither a str nor an int
        """
        index = self._column_index(column)
        return [self._cell(item, index) for item in self]

    def add_column(self, name: str, data, itemname: str = '', missing=None) -> RecordList:
        """
        Return a new RecordList with the added data as a column

        If len(data) < len(self), pad data with missing. Values beyond
        len(self) are ignored. Rows of the result are plain tuples

        Args:
            name: the name of the new column
            data: the data of the column
            itemname: the name of each item
            missing: value to use when padding is needed

        Returns:
            the resulting RecordList
        """
        itemname = itemname or self.item_name
        columns = list(self.columns) + [name]
        padded = itertools.chain(data, itertools.repeat(missing))
        newdata = [self._row_values(row) + (value,) for row, value in zip(self, padded)]
        return RecordList(newdata, columns, itemname)

    def remove_column(self, colname: str) -> RecordList:
        """
        Return a new RecordList with the column removed

        If the column does not exist, a copy of self is returned

        Args:
            colname: the name of the column to remove

        Returns:
            the resulting RecordList
        """
        if colname not in self.columns:
            return self.copy()
        return self.get_columns([col for col in self.columns if col != colname])

    # ######################################################
    # operations with other RecordLists
    # ######################################################

    def merge_with(self, other: RecordList) -> RecordList:
        """
        A new list is returned with a union of the fields of self and other

        If there are fields in common, other prevails (similar to dict.update)
        If self and other have a different number of rows, the lowest
        is taken.

        Args:
            other: the RecordList to merge with

        Returns:
            the merged RecordList
        """
        if not isinstance(other, list) or not hasattr(other, "columns"):
            raise TypeError("other should be a RecordList")
        columns = list(self.columns)
        for othercol in other.columns:
            if othercol not in columns:
                columns.append(othercol)
        coldata = [other.get_column(col) if col in other.columns else self.get_column(col)
                   for col in columns]
        # zip truncates to the shortest column, i.e. the lowest row count
        return RecordList(list(zip(*coldata)), columns)

    def get_columns(self, columns: list[str]) -> RecordList:
        """
        Returns a new RecordList with the selected columns

        Args:
            columns: a list of column names

        Returns:
            the resulting RecordList, with namedtuples as rows
        """
        columns = list(columns)
        data_columns = [self.get_column(column) for column in columns]
        constructor = _namedtuple(self.item_name, columns)
        items = [constructor(*row) for row in zip(*data_columns)]
        # Pass the fields explicitly so that an empty result is still valid
        return RecordList(items, fields=columns, itemname=self.item_name)

    @property
    def item_constructor(self):
        if self._item_constructor is None:
            c = _namedtuple(self.item_name, self.columns)
            self._item_constructor = c
        return self._item_constructor

    def sort_by(self, column: str) -> None:
        """
        Sort this RecordList (in place) by the given column

        Args:
            column: the column name (or index) to use to sort this RecordList
        """
        if isinstance(column, str) and column not in self.columns:
            # Not a known column: keep sorting by an attribute of the items,
            # as before (raises AttributeError if the attribute is missing)
            self.sort(key=lambda item: getattr(item, column))
            return
        index = self._column_index(column)
        # Index-based access also works for rows which are plain tuples
        self.sort(key=lambda item: self._cell(item, index))

    @classmethod
    def from_csv(cls, csvfile: str) -> RecordList:
        """
        Create a new RecordList with the data in csvfile
        """
        from .csvtools import readcsv
        rows = readcsv(csvfile)
        return cls(rows)

    def to_csv(self, outfile: str) -> None:
        """
        Write the data in this RecordList as a csv file
        """
        from .csvtools import writecsv
        writecsv(self, outfile, column_names=self.columns)

    @classmethod
    def from_dataframe(cls, dataframe, itemname="row") -> RecordList:
        """
        create a RecordList from a pandas.DataFrame

        Args:
            dataframe: the pandas.DataFrame to convert
            itemname: the name of each row

        Returns:
            a RecordList with one namedtuple per row of the dataframe
        """
        columns = _validate_fields(list(dataframe.keys()))
        Row = _namedtuple(itemname, columns)
        # DataFrame.irow no longer exists in pandas; itertuples with
        # name=None yields the values of each row as a plain tuple
        out = [Row(*values) for values in dataframe.itertuples(index=False, name=None)]
        # Pass the fields explicitly so that an empty dataframe works
        return cls(out, fields=columns, itemname=itemname)

    def to_dataframe(self):
        """
        create a pandas.DataFrame from this RecordList

        Raises:
            ImportError: if pandas is not installed
        """
        try:
            import pandas
        except ImportError as err:
            raise ImportError("pandas is needed to export to pandas.DataFrame!") from err
        # Plain value tuples: columns are matched by position for every row type
        return pandas.DataFrame([self._row_values(row) for row in self], columns=self.columns)
def _validate_fields(field_names: list[str]) -> list[str]:
"""
Validate the given field names
Args:
field_names: a list of strings to be used as attributes.
Example
=======
.. code::
# Numbers are not valid identifiers
# an object cannot have non-unique attributes
>>> _validate_fields(["0", "field", "field"])
['_0', 'field', '_2']
"""
names = list(map(str, list(field_names)))
seen = set()
for i, name in enumerate(names):
if (
not all(c.isalnum() or c == "_" for c in name)
or _iskeyword(name)
or not name
or name[0].isdigit()
or name.startswith("_")
or name in seen
):
names[i] = "field%d" % i
seen.add(name)
return names