diff --git a/roll/roll.py b/roll/roll.py index 03b0488..d6ac444 100644 --- a/roll/roll.py +++ b/roll/roll.py @@ -12,13 +12,13 @@ ROLL_PATTERN = re.compile(r"(\d+)d(\d+)([+-]\d+)?") @dataclasses.dataclass(frozen=True) class Roll: - """A roll of one or more dice""" + """a roll of one or more dice""" - dice_count: int - sides: int + dice_count: int = 1 + sides: int = 20 modifier: int | None = None - def __post_init__(self): + def __post_init__(self) -> None: if self.dice_count < 1: msg = "dice must be greater than 0" raise ValueError(msg) @@ -28,7 +28,7 @@ class Roll: @classmethod def from_str(cls, value: str) -> Self: - """Parse a Roll from it's short representation, e.g. 2d6 or 1d20-2""" + """parse a Roll from its short representation, e.g. 2d6 or 1d20-2""" match = ROLL_PATTERN.fullmatch(value) if match is None: msg = f"expected {value!r} to match pattern {ROLL_PATTERN.pattern!r}" @@ -37,21 +37,21 @@ class Roll: return cls(int(dice_count), int(sides), int(modifier) if modifier else None) def modifier_str(self) -> str: - """Return the modifier as a string""" + """return the modifier as a string""" if self.modifier is None: return "" sign = "+" if self.modifier > 0 else "" return f"{sign}{self.modifier}" def to_str(self) -> str: - """Return the short representation of a roll, e.g. 3d4 or 2d20+3""" + """return the short representation of a roll, e.g. 3d4 or 2d20+3""" return f"{self.dice_count}d{self.sides}{self.modifier_str()}" def modify(self, modifier: int) -> Self: - """Return a new Roll with the given modifier""" + """return a new Roll with the given modifier""" return dataclasses.replace(self, modifier=modifier) def throw(self) -> Throw: - """Throw the dice""" + """throw the dice""" throw = [random.randint(1, self.sides) for _ in range(self.dice_count)] - return Throw(throw, self.modifier) + return Throw(throw, self.sides, self.modifier) diff --git a/roll/throw.py b/roll/throw.py index 0c00d76..ce38119 100644 --- a/roll/throw.py +++ b/roll/throw.py @@ -1,12 +1,25 @@ import dataclasses +D20_MAX = 20 + @dataclasses.dataclass(frozen=True) class Throw: results: list[int] - modifier: int | None + sides: int + modifier: int | None = None @property def total(self) -> int: - """Calculate the total of the throw, accounting for the modifier""" + """calculate the total of the throw, accounting for the modifier""" return sum(self.results) + (self.modifier or 0) + + @property + def is_critical_hit(self) -> bool: + """check if the throw is a 20 on a 1d20""" + return self.sides == D20_MAX and self.results == [D20_MAX] + + @property + def is_critical_miss(self) -> bool: + """check if the throw is a 1 on a 1d20""" + return self.sides == D20_MAX and self.results == [1] diff --git a/tests/roll_test.py b/tests/roll_test.py index af5bf3b..eaec67c 100644 --- a/tests/roll_test.py +++ b/tests/roll_test.py @@ -1,4 +1,4 @@ -import pytest # type: ignore (TODO: figure out why pyright can't import pytest) +import pytest from roll.roll import Roll @@ -11,7 +11,8 @@ def test_roll_validation(): @pytest.mark.parametrize( - ("roll", "expected"), [(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")] + ("roll", "expected"), + [(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")], ) def test_str_roundtrip(roll: Roll, expected: str): assert roll.to_str() == expected @@ -29,7 +30,7 @@ def test_modify(): modified_roll = roll.modify(3) assert modified_roll == Roll(2, 20, 3) assert roll == Roll(2, 20) - assert modified_roll is not roll + assert modified_roll is not roll and modified_roll != roll @pytest.mark.parametrize("n", list(range(1, 5))) diff --git a/tests/throw_test.py b/tests/throw_test.py new file mode 100644 index 0000000..a94d7de --- /dev/null +++ b/tests/throw_test.py @@ -0,0 +1,29 @@ +from roll.throw import Throw + + +def test_throw(): + throw = Throw([1, 2, 3], sides=4) + assert throw.total == 6 + assert not throw.is_critical_hit + assert not throw.is_critical_miss + + +def test_throw_with_modifier(): + throw = Throw([1, 2, 3], sides=4, modifier=5) + assert throw.total == 11 + assert not throw.is_critical_hit + assert not throw.is_critical_miss + + +def test_critical_hit(): + throw = Throw([20], sides=20) + assert throw.total == 20 + assert throw.is_critical_hit + assert not throw.is_critical_miss + + +def test_critical_miss(): + throw = Throw([1], sides=20) + assert throw.total == 1 + assert not throw.is_critical_hit + assert throw.is_critical_miss