aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--patchtree/cli.py51
-rw-r--r--patchtree/config.py3
-rw-r--r--patchtree/context.py83
-rw-r--r--patchtree/patch.py3
4 files changed, 100 insertions, 40 deletions
diff --git a/patchtree/cli.py b/patchtree/cli.py
index 9ddb4fb..bef0c0d 100644
--- a/patchtree/cli.py
+++ b/patchtree/cli.py
@@ -1,6 +1,7 @@
from dataclasses import fields
from sys import stderr
from pathlib import Path
+from argparse import ArgumentTypeError
from .context import Context
from .config import Config
@@ -16,6 +17,13 @@ def load_config() -> Config:
return Config(**init)
+def path_dir(path: str) -> Path:
+ out = Path(path)
+ if not out.is_dir():
+ raise ArgumentTypeError(f"not a directory: `{path}'")
+ return out
+
+
def parse_arguments(config: Config) -> Context:
parser = config.argument_parser(
prog="patchtree",
@@ -36,6 +44,7 @@ def parse_arguments(config: Config) -> Context:
parser.add_argument(
"-c",
"--context",
+ metavar="NUM",
help="lines of context in output diff",
type=int,
)
@@ -45,8 +54,32 @@ def parse_arguments(config: Config) -> Context:
help="output shebang in resulting patch",
action="store_true",
)
- parser.add_argument("target", help="target directory or archive")
- parser.add_argument("patch", help="patch input glob(s)", nargs="*")
+ parser.add_argument(
+ "-C",
+ "--root",
+ metavar="DIR",
+ help="patchset root directory",
+ type=path_dir,
+ )
+ parser.add_argument(
+ "-g",
+ "--glob",
+ help="enable globbing for input(s)",
+ action="store_true",
+ )
+ parser.add_argument(
+ "target",
+ metavar="TARGET",
+ help="target directory or archive",
+ type=Path,
+ )
+ parser.add_argument(
+ "patch",
+ metavar="PATCH",
+ help="patchset input(s)",
+ nargs="*",
+ type=str,
+ )
options = parser.parse_args()
@@ -70,24 +103,18 @@ def main():
context = parse_arguments(config)
- file_set: set[Path] = set()
- for pattern in context.options.patch:
- for path in Path(".").glob(pattern):
- if not path.is_file():
- continue
- file_set.add(path)
- files = sorted(file_set)
-
- if len(files) == 0:
+ if len(context.inputs) == 0:
print("no files to patch!", file=stderr)
return 0
config.header(config, context)
- for file in files:
+ for file in context.inputs:
patch = config.patch(config, file)
patch.write(context)
+ context.close()
+
return 0
diff --git a/patchtree/config.py b/patchtree/config.py
index fbe78e9..2f2373e 100644
--- a/patchtree/config.py
+++ b/patchtree/config.py
@@ -3,6 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from argparse import ArgumentParser
from importlib import metadata
+from pathlib import Path
from .context import Context
from .patch import Patch
@@ -95,7 +96,7 @@ class Config:
"""Whether to output a shebang line with the ``git patch`` command to apply
the patch."""
- default_patch_sources: list[str] = field(default_factory=list)
+ default_patch_sources: list[Path] = field(default_factory=list)
"""List of default sources."""
def __post_init__(self):
diff --git a/patchtree/context.py b/patchtree/context.py
index 2538849..cca419a 100644
--- a/patchtree/context.py
+++ b/patchtree/context.py
@@ -18,9 +18,9 @@ ZIP_CREATE_SYSTEM_UNX = 3
class FS:
- target: str
+ target: Path
- def __init__(self, target: str):
+ def __init__(self, target: Path):
self.target = target
def get_dir(self, dir: str) -> list[str]:
@@ -34,18 +34,15 @@ class FS:
class DiskFS(FS):
- path: Path
-
- def __init__(self, target: str):
+ def __init__(self, target):
super(DiskFS, self).__init__(target)
- self.path = Path(target)
def get_dir(self, dir: str) -> list[str]:
- here = self.path.joinpath(dir)
+ here = self.target.joinpath(dir)
return [path.name for path in here.iterdir()]
def get_content(self, file: str) -> str | None:
- here = self.path.joinpath(file)
+ here = self.target.joinpath(file)
if not here.exists():
return None
bytes = here.read_bytes()
@@ -55,7 +52,7 @@ class DiskFS(FS):
return ""
def get_mode(self, file: str) -> int:
- here = self.path.joinpath(file)
+ here = self.target.joinpath(file)
if not here.exists():
return 0
return here.stat().st_mode
@@ -65,9 +62,9 @@ class ZipFS(FS):
zip: ZipFile
files: dict[Path, ZipInfo] = {}
- def __init__(self, target: str):
+ def __init__(self, target):
super(ZipFS, self).__init__(target)
- self.zip = ZipFile(target)
+ self.zip = ZipFile(str(target))
for info in self.zip.infolist():
self.files[Path(info.filename)] = info
@@ -117,29 +114,61 @@ class Context:
fs: FS
output: IO
+ root: Path
+ target: Path
+ inputs: list[Path] = []
+ in_place: bool
+
config: Config
- options: Namespace
def __init__(self, config: Config, options: Namespace):
self.config = config
- self.options = options
+
+ self.root = options.root
+ self.target = options.target
+ self.in_place = options.in_place
+ self.inputs = self.collect_inputs(options)
self.fs = self.get_fs()
- self.output = self.get_output()
+ self.output = self.get_output(options)
- if self.options.in_place:
+ if self.in_place:
self.apply(True)
- def __del__(self):
+ def close(self):
# patch must have a trailing newline
self.output.write("\n")
self.output.flush()
- if self.options.in_place:
+ if self.in_place:
self.apply(False)
self.output.close()
+ def collect_inputs(self, options: Namespace) -> list[Path]:
+ inputs: set[Path] = set()
+
+ if len(inputs) == 0:
+ options.glob = True
+ options.patch = [str(Path(options.root or ".").joinpath("**"))]
+
+ if options.glob:
+ for pattern in options.patch:
+ for path in Path(".").glob(pattern):
+ if not path.is_file():
+ continue
+ inputs.add(path)
+ return sorted(inputs)
+ else:
+ for input in options.patch:
+ path = Path(input)
+ if not path.exists():
+ raise Exception(f"cannot open `{input}'")
+ if not path.is_file():
+ raise Exception(f"not a file: `{input}'")
+ inputs.add(path)
+ return list(inputs)
+
def get_dir(self, dir: str) -> list[str]:
return self.fs.get_dir(dir)
@@ -150,32 +179,32 @@ class Context:
return self.fs.get_mode(file)
def get_fs(self) -> FS:
- target: str = self.options.target
+ target = self.target
- if not path.exists(target):
+ if not target.exists():
raise Exception(f"cannot open `{target}'")
if path.isdir(target):
return DiskFS(target)
if is_zipfile(target):
- if self.options.in_place:
+ if self.in_place:
raise Exception("cannot edit zip in-place!")
return ZipFS(target)
raise Exception("cannot read `{target}'")
- def get_output(self) -> IO:
- if self.options.in_place:
- if self.options.out is not None:
+ def get_output(self, options: Namespace) -> IO:
+ if self.in_place:
+ if options.out is not None:
print("warning: --out is ignored when using --in-place", file=stderr)
return TemporaryFile("w+")
- if self.options.out is not None:
- if self.options.out == "-":
+ if options.out is not None:
+ if options.out == "-":
return stdout
else:
- return open(self.options.out, "w+")
+ return open(options.out, "w+")
return stdout
@@ -190,7 +219,7 @@ class Context:
return cmd
def apply(self, reverse: bool) -> None:
- location = cast(DiskFS, self.fs).path
+ location = cast(DiskFS, self.fs).target
cache = location.joinpath(".patchtree.diff")
cmd = self.get_apply_cmd()
diff --git a/patchtree/patch.py b/patchtree/patch.py
index 32d3b7f..bff7ee4 100644
--- a/patchtree/patch.py
+++ b/patchtree/patch.py
@@ -42,6 +42,9 @@ class Patch:
)
def write(self, context: Context) -> None:
+ if context.root is not None:
+ self.file = str(Path(self.file).relative_to(context.root))
+
diff = Diff(self.config, self.file)
diff.a = File(