diff --git a/README.md b/README.md index 05ed83f..43738b1 100644 --- a/README.md +++ b/README.md @@ -22,16 +22,9 @@ total | 14 $ roll 1d20 1 ..... 1 total | 1 -critical miss! - -$ roll 1d20+a # or 1d20+advantage -1a .... 19 -1b .... 2 -total | 19 - -$ roll -> 1d20 ->> total: 12 -> +5 ->> total + 5: 17 ``` + +## todo + +- [ ] roll with (dis)advantage +- [ ] interactive rolling mode diff --git a/pyproject.toml b/pyproject.toml index e5fdfff..b257eae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,10 +61,9 @@ dependencies = [ "black>=23.1.0", "pyright>=1.1.319", "ruff>=0.0.243", - "pytest", # for types in tests ] [tool.hatch.envs.lint.scripts] -typing = "pyright --project pyproject.toml --dependencies {args:roll tests}" +typing = "pyright --project pyproject.toml {args:roll tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", diff --git a/roll/cli/__init__.py b/roll/cli/__init__.py index 061f4b4..f6e9286 100644 --- a/roll/cli/__init__.py +++ b/roll/cli/__init__.py @@ -1,6 +1,8 @@ import click from roll.__about__ import __version__ +from roll.cli.roll_param import ROLL +from roll.roll import Roll @click.group( @@ -8,5 +10,41 @@ from roll.__about__ import __version__ invoke_without_command=True, ) @click.version_option(version=__version__, prog_name="roll") -def roll(): - click.echo("Hello world!") +@click.argument("rolls", nargs=-1, type=ROLL) +def roll(rolls: list[Roll]): + """Throw each roll specified in ROLLS and print the results. + + Rolls are specified as + + DdS[(+|-)M] + + where D = # of dice, S = sides per die, and M = optional modifier. + + Example usage: + + \b + $ roll 2d20+3 3d6-1 + rolling 2d20+3: + 1: | 2 + 2: | 13 + mod: +3 + total: 18 + \b + rolling 3d6-1: + 1: | 3 + 2: | 3 + 3: | 1 + mod: -1 + total: 6 + + """ + for roll in rolls: + click.echo() + click.echo(f"rolling {roll.to_str()}:") + throw = roll.throw() + for i, result in enumerate(throw.results): + click.echo(f"{i + 1}:\t| {result: >3}") + if roll.modifier: + mod = roll.modifier_str() + click.echo(f"mod:\t {mod: >4}") + click.echo(f"total:\t {throw.total: >3}") diff --git a/roll/cli/roll_param.py b/roll/cli/roll_param.py new file mode 100644 index 0000000..8db7699 --- /dev/null +++ b/roll/cli/roll_param.py @@ -0,0 +1,19 @@ +import click + +from roll.roll import Roll + + +class RollParam(click.ParamType): + name = "roll" + + def convert(self, value: str | Roll, param: click.Parameter | None, ctx: click.Context | None) -> Roll: + """Parse a Roll from a command line string.""" + if isinstance(value, Roll): + return value + try: + return Roll.from_str(value) + except Exception as e: + self.fail(f"invalid roll: {value!r}, caused by {e}", param, ctx) + + +ROLL = RollParam() diff --git a/roll/roll.py b/roll/roll.py new file mode 100644 index 0000000..03b0488 --- /dev/null +++ b/roll/roll.py @@ -0,0 +1,57 @@ +import dataclasses +import random +import re +from typing import Self + +from roll.throw import Throw + +MIN_SIDES = 2 + +ROLL_PATTERN = re.compile(r"(\d+)d(\d+)([+-]\d+)?") + + +@dataclasses.dataclass(frozen=True) +class Roll: + """A roll of one or more dice""" + + dice_count: int + sides: int + modifier: int | None = None + + def __post_init__(self): + if self.dice_count < 1: + msg = "dice must be greater than 0" + raise ValueError(msg) + if self.sides < MIN_SIDES: + msg = "sides must be greater than 1" + raise ValueError(msg) + + @classmethod + def from_str(cls, value: str) -> Self: + """Parse a Roll from it's short representation, e.g. 2d6 or 1d20-2""" + match = ROLL_PATTERN.fullmatch(value) + if match is None: + msg = f"expected {value!r} to match pattern {ROLL_PATTERN.pattern!r}" + raise ValueError(msg) + dice_count, sides, modifier = match.groups() + return cls(int(dice_count), int(sides), int(modifier) if modifier else None) + + def modifier_str(self) -> str: + """Return the modifier as a string""" + if self.modifier is None: + return "" + sign = "+" if self.modifier > 0 else "" + return f"{sign}{self.modifier}" + + def to_str(self) -> str: + """Return the short representation of a roll, e.g. 3d4 or 2d20+3""" + return f"{self.dice_count}d{self.sides}{self.modifier_str()}" + + def modify(self, modifier: int) -> Self: + """Return a new Roll with the given modifier""" + return dataclasses.replace(self, modifier=modifier) + + def throw(self) -> Throw: + """Throw the dice""" + throw = [random.randint(1, self.sides) for _ in range(self.dice_count)] + return Throw(throw, self.modifier) diff --git a/roll/throw.py b/roll/throw.py new file mode 100644 index 0000000..0c00d76 --- /dev/null +++ b/roll/throw.py @@ -0,0 +1,12 @@ +import dataclasses + + +@dataclasses.dataclass(frozen=True) +class Throw: + results: list[int] + modifier: int | None + + @property + def total(self) -> int: + """Calculate the total of the throw, accounting for the modifier""" + return sum(self.results) + (self.modifier or 0) diff --git a/tests/roll_test.py b/tests/roll_test.py new file mode 100644 index 0000000..af5bf3b --- /dev/null +++ b/tests/roll_test.py @@ -0,0 +1,41 @@ +import pytest # type: ignore (TODO: figure out why pyright can't import pytest) + +from roll.roll import Roll + + +def test_roll_validation(): + with pytest.raises(ValueError): + Roll(0, 20) + with pytest.raises(ValueError): + Roll(1, 1) + + +@pytest.mark.parametrize( + ("roll", "expected"), [(Roll(1, 20), "1d20"), (Roll(2, 20, 3), "2d20+3"), (Roll(4, 6, -3), "4d6-3")] +) +def test_str_roundtrip(roll: Roll, expected: str): + assert roll.to_str() == expected + assert Roll.from_str(expected) == roll + + +@pytest.mark.parametrize("roll", ["d90", "0d0", "-1d-1", "aba", "000", "1d1d1d"]) +def test_from_bad_str(roll: str): + with pytest.raises(ValueError): + Roll.from_str(roll) + + +def test_modify(): + roll = Roll(2, 20) + modified_roll = roll.modify(3) + assert modified_roll == Roll(2, 20, 3) + assert roll == Roll(2, 20) + assert modified_roll is not roll + + +@pytest.mark.parametrize("n", list(range(1, 5))) +@pytest.mark.parametrize("sides", list(range(2, 100))) +def test_throw(n: int, sides: int): + roll = Roll(n, sides) + throw = roll.throw() + assert len(throw.results) == n + assert all(1 <= result <= sides for result in throw.results)