Update roll data structures

main
mat ess 2023-08-13 14:53:51 -04:00
parent e985ca230c
commit f192f1f1f3
4 changed files with 58 additions and 15 deletions

View File

@ -12,13 +12,13 @@ ROLL_PATTERN = re.compile(r"(\d+)d(\d+)([+-]\d+)?")
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Roll: class Roll:
"""A roll of one or more dice""" """a roll of one or more dice"""
dice_count: int dice_count: int = 1
sides: int sides: int = 20
modifier: int | None = None modifier: int | None = None
def __post_init__(self): def __post_init__(self) -> None:
if self.dice_count < 1: if self.dice_count < 1:
msg = "dice must be greater than 0" msg = "dice must be greater than 0"
raise ValueError(msg) raise ValueError(msg)
@ -28,7 +28,7 @@ class Roll:
@classmethod @classmethod
def from_str(cls, value: str) -> Self: def from_str(cls, value: str) -> Self:
"""Parse a Roll from it's short representation, e.g. 2d6 or 1d20-2""" """parse a Roll from its short representation, e.g. 2d6 or 1d20-2"""
match = ROLL_PATTERN.fullmatch(value) match = ROLL_PATTERN.fullmatch(value)
if match is None: if match is None:
msg = f"expected {value!r} to match pattern {ROLL_PATTERN.pattern!r}" msg = f"expected {value!r} to match pattern {ROLL_PATTERN.pattern!r}"
@ -37,21 +37,21 @@ class Roll:
return cls(int(dice_count), int(sides), int(modifier) if modifier else None) return cls(int(dice_count), int(sides), int(modifier) if modifier else None)
def modifier_str(self) -> str: def modifier_str(self) -> str:
"""Return the modifier as a string""" """return the modifier as a string"""
if self.modifier is None: if self.modifier is None:
return "" return ""
sign = "+" if self.modifier > 0 else "" sign = "+" if self.modifier > 0 else ""
return f"{sign}{self.modifier}" return f"{sign}{self.modifier}"
def to_str(self) -> str: def to_str(self) -> str:
"""Return the short representation of a roll, e.g. 3d4 or 2d20+3""" """return the short representation of a roll, e.g. 3d4 or 2d20+3"""
return f"{self.dice_count}d{self.sides}{self.modifier_str()}" return f"{self.dice_count}d{self.sides}{self.modifier_str()}"
def modify(self, modifier: int) -> Self: def modify(self, modifier: int) -> Self:
"""Return a new Roll with the given modifier""" """return a new Roll with the given modifier"""
return dataclasses.replace(self, modifier=modifier) return dataclasses.replace(self, modifier=modifier)
def throw(self) -> Throw: def throw(self) -> Throw:
"""Throw the dice""" """throw the dice"""
throw = [random.randint(1, self.sides) for _ in range(self.dice_count)] throw = [random.randint(1, self.sides) for _ in range(self.dice_count)]
return Throw(throw, self.modifier) return Throw(throw, self.sides, self.modifier)

View File

@ -1,12 +1,25 @@
import dataclasses import dataclasses
D20_MAX = 20
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Throw: class Throw:
results: list[int] results: list[int]
modifier: int | None sides: int
modifier: int | None = None
@property @property
def total(self) -> int: def total(self) -> int:
"""Calculate the total of the throw, accounting for the modifier""" """calculate the total of the throw, accounting for the modifier"""
return sum(self.results) + (self.modifier or 0) return sum(self.results) + (self.modifier or 0)
@property
def is_critical_hit(self) -> bool:
"""check if the throw is a 20 on a 1d20"""
return self.sides == D20_MAX and self.results == [D20_MAX]
@property
def is_critical_miss(self) -> bool:
"""check if the throw is a 1 on a 1d20"""
return self.sides == D20_MAX and self.results == [1]

View File

@ -1,4 +1,4 @@
import pytest # type: ignore (TODO: figure out why pyright can't import pytest) import pytest
from roll.roll import Roll from roll.roll import Roll
@ -11,7 +11,8 @@ def test_roll_validation():
@pytest.mark.parametrize( @pytest.mark.parametrize(
("roll", "expected"), [(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")] ("roll", "expected"),
[(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")],
) )
def test_str_roundtrip(roll: Roll, expected: str): def test_str_roundtrip(roll: Roll, expected: str):
assert roll.to_str() == expected assert roll.to_str() == expected
@ -29,7 +30,7 @@ def test_modify():
modified_roll = roll.modify(3) modified_roll = roll.modify(3)
assert modified_roll == Roll(2, 20, 3) assert modified_roll == Roll(2, 20, 3)
assert roll == Roll(2, 20) assert roll == Roll(2, 20)
assert modified_roll is not roll assert modified_roll is not roll and modified_roll != roll
@pytest.mark.parametrize("n", list(range(1, 5))) @pytest.mark.parametrize("n", list(range(1, 5)))

29
tests/throw_test.py Normal file
View File

@ -0,0 +1,29 @@
from roll.throw import Throw
def test_throw():
throw = Throw([1, 2, 3], sides=4)
assert throw.total == 6
assert not throw.is_critical_hit
assert not throw.is_critical_miss
def test_throw_with_modifier():
throw = Throw([1, 2, 3], sides=4, modifier=5)
assert throw.total == 11
assert not throw.is_critical_hit
assert not throw.is_critical_miss
def test_critical_hit():
throw = Throw([20], sides=20)
assert throw.total == 20
assert throw.is_critical_hit
assert not throw.is_critical_miss
def test_critical_miss():
throw = Throw([1], sides=20)
assert throw.total == 1
assert not throw.is_critical_hit
assert throw.is_critical_miss