Update roll data structures
parent
e985ca230c
commit
f192f1f1f3
20
roll/roll.py
20
roll/roll.py
|
@ -12,13 +12,13 @@ ROLL_PATTERN = re.compile(r"(\d+)d(\d+)([+-]\d+)?")
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Roll:
|
class Roll:
|
||||||
"""A roll of one or more dice"""
|
"""a roll of one or more dice"""
|
||||||
|
|
||||||
dice_count: int
|
dice_count: int = 1
|
||||||
sides: int
|
sides: int = 20
|
||||||
modifier: int | None = None
|
modifier: int | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
if self.dice_count < 1:
|
if self.dice_count < 1:
|
||||||
msg = "dice must be greater than 0"
|
msg = "dice must be greater than 0"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -28,7 +28,7 @@ class Roll:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_str(cls, value: str) -> Self:
|
def from_str(cls, value: str) -> Self:
|
||||||
"""Parse a Roll from it's short representation, e.g. 2d6 or 1d20-2"""
|
"""parse a Roll from its short representation, e.g. 2d6 or 1d20-2"""
|
||||||
match = ROLL_PATTERN.fullmatch(value)
|
match = ROLL_PATTERN.fullmatch(value)
|
||||||
if match is None:
|
if match is None:
|
||||||
msg = f"expected {value!r} to match pattern {ROLL_PATTERN.pattern!r}"
|
msg = f"expected {value!r} to match pattern {ROLL_PATTERN.pattern!r}"
|
||||||
|
@ -37,21 +37,21 @@ class Roll:
|
||||||
return cls(int(dice_count), int(sides), int(modifier) if modifier else None)
|
return cls(int(dice_count), int(sides), int(modifier) if modifier else None)
|
||||||
|
|
||||||
def modifier_str(self) -> str:
|
def modifier_str(self) -> str:
|
||||||
"""Return the modifier as a string"""
|
"""return the modifier as a string"""
|
||||||
if self.modifier is None:
|
if self.modifier is None:
|
||||||
return ""
|
return ""
|
||||||
sign = "+" if self.modifier > 0 else ""
|
sign = "+" if self.modifier > 0 else ""
|
||||||
return f"{sign}{self.modifier}"
|
return f"{sign}{self.modifier}"
|
||||||
|
|
||||||
def to_str(self) -> str:
|
def to_str(self) -> str:
|
||||||
"""Return the short representation of a roll, e.g. 3d4 or 2d20+3"""
|
"""return the short representation of a roll, e.g. 3d4 or 2d20+3"""
|
||||||
return f"{self.dice_count}d{self.sides}{self.modifier_str()}"
|
return f"{self.dice_count}d{self.sides}{self.modifier_str()}"
|
||||||
|
|
||||||
def modify(self, modifier: int) -> Self:
|
def modify(self, modifier: int) -> Self:
|
||||||
"""Return a new Roll with the given modifier"""
|
"""return a new Roll with the given modifier"""
|
||||||
return dataclasses.replace(self, modifier=modifier)
|
return dataclasses.replace(self, modifier=modifier)
|
||||||
|
|
||||||
def throw(self) -> Throw:
|
def throw(self) -> Throw:
|
||||||
"""Throw the dice"""
|
"""throw the dice"""
|
||||||
throw = [random.randint(1, self.sides) for _ in range(self.dice_count)]
|
throw = [random.randint(1, self.sides) for _ in range(self.dice_count)]
|
||||||
return Throw(throw, self.modifier)
|
return Throw(throw, self.sides, self.modifier)
|
||||||
|
|
|
@ -1,12 +1,25 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
|
D20_MAX = 20
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Throw:
|
class Throw:
|
||||||
results: list[int]
|
results: list[int]
|
||||||
modifier: int | None
|
sides: int
|
||||||
|
modifier: int | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total(self) -> int:
|
def total(self) -> int:
|
||||||
"""Calculate the total of the throw, accounting for the modifier"""
|
"""calculate the total of the throw, accounting for the modifier"""
|
||||||
return sum(self.results) + (self.modifier or 0)
|
return sum(self.results) + (self.modifier or 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical_hit(self) -> bool:
|
||||||
|
"""check if the throw is a 20 on a 1d20"""
|
||||||
|
return self.sides == D20_MAX and self.results == [D20_MAX]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical_miss(self) -> bool:
|
||||||
|
"""check if the throw is a 1 on a 1d20"""
|
||||||
|
return self.sides == D20_MAX and self.results == [1]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import pytest # type: ignore (TODO: figure out why pyright can't import pytest)
|
import pytest
|
||||||
|
|
||||||
from roll.roll import Roll
|
from roll.roll import Roll
|
||||||
|
|
||||||
|
@ -11,7 +11,8 @@ def test_roll_validation():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("roll", "expected"), [(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")]
|
("roll", "expected"),
|
||||||
|
[(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")],
|
||||||
)
|
)
|
||||||
def test_str_roundtrip(roll: Roll, expected: str):
|
def test_str_roundtrip(roll: Roll, expected: str):
|
||||||
assert roll.to_str() == expected
|
assert roll.to_str() == expected
|
||||||
|
@ -29,7 +30,7 @@ def test_modify():
|
||||||
modified_roll = roll.modify(3)
|
modified_roll = roll.modify(3)
|
||||||
assert modified_roll == Roll(2, 20, 3)
|
assert modified_roll == Roll(2, 20, 3)
|
||||||
assert roll == Roll(2, 20)
|
assert roll == Roll(2, 20)
|
||||||
assert modified_roll is not roll
|
assert modified_roll is not roll and modified_roll != roll
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", list(range(1, 5)))
|
@pytest.mark.parametrize("n", list(range(1, 5)))
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
from roll.throw import Throw
|
||||||
|
|
||||||
|
|
||||||
|
def test_throw():
|
||||||
|
throw = Throw([1, 2, 3], sides=4)
|
||||||
|
assert throw.total == 6
|
||||||
|
assert not throw.is_critical_hit
|
||||||
|
assert not throw.is_critical_miss
|
||||||
|
|
||||||
|
|
||||||
|
def test_throw_with_modifier():
|
||||||
|
throw = Throw([1, 2, 3], sides=4, modifier=5)
|
||||||
|
assert throw.total == 11
|
||||||
|
assert not throw.is_critical_hit
|
||||||
|
assert not throw.is_critical_miss
|
||||||
|
|
||||||
|
|
||||||
|
def test_critical_hit():
|
||||||
|
throw = Throw([20], sides=20)
|
||||||
|
assert throw.total == 20
|
||||||
|
assert throw.is_critical_hit
|
||||||
|
assert not throw.is_critical_miss
|
||||||
|
|
||||||
|
|
||||||
|
def test_critical_miss():
|
||||||
|
throw = Throw([1], sides=20)
|
||||||
|
assert throw.total == 1
|
||||||
|
assert not throw.is_critical_hit
|
||||||
|
assert throw.is_critical_miss
|
Loading…
Reference in New Issue