about summary refs log tree commit diff
path: root/library/compiler-builtins/ci/ci-util.py
diff options
context:
space:
mode:
Diffstat (limited to 'library/compiler-builtins/ci/ci-util.py')
-rwxr-xr-xlibrary/compiler-builtins/ci/ci-util.py126
1 files changed, 89 insertions, 37 deletions
diff --git a/library/compiler-builtins/ci/ci-util.py b/library/compiler-builtins/ci/ci-util.py
index 3437d304f48..c1db17c6c90 100755
--- a/library/compiler-builtins/ci/ci-util.py
+++ b/library/compiler-builtins/ci/ci-util.py
@@ -7,10 +7,12 @@ git history.
 
 import json
 import os
+import pprint
 import re
 import subprocess as sp
 import sys
 from dataclasses import dataclass
+from functools import cache
 from glob import glob
 from inspect import cleandoc
 from os import getenv
@@ -50,15 +52,6 @@ GIT = ["git", "-C", REPO_ROOT]
 DEFAULT_BRANCH = "master"
 WORKFLOW_NAME = "CI"  # Workflow that generates the benchmark artifacts
 ARTIFACT_PREFIX = "baseline-icount*"
-# Place this in a PR body to skip regression checks (must be at the start of a line).
-REGRESSION_DIRECTIVE = "ci: allow-regressions"
-# Place this in a PR body to skip extensive tests
-SKIP_EXTENSIVE_DIRECTIVE = "ci: skip-extensive"
-# Place this in a PR body to allow running a large number of extensive tests. If not
-# set, this script will error out if a threshold is exceeded in order to avoid
-# accidentally spending huge amounts of CI time.
-ALLOW_MANY_EXTENSIVE_DIRECTIVE = "ci: allow-many-extensive"
-MANY_EXTENSIVE_THRESHOLD = 20
 
 # Don't run exhaustive tests if these files change, even if they contaiin a function
 # definition.
@@ -70,7 +63,7 @@ IGNORE_FILES = [
 
 # libm PR CI takes a long time and doesn't need to run unless relevant files have been
 # changed. Anything matching this regex pattern will trigger a run.
-TRIGGER_LIBM_PR_CI = ".*(libm|musl).*"
+TRIGGER_LIBM_CI_FILE_PAT = ".*(libm|musl).*"
 
 TYPES = ["f16", "f32", "f64", "f128"]
 
@@ -80,6 +73,54 @@ def eprint(*args, **kwargs):
     print(*args, file=sys.stderr, **kwargs)
 
 
+@dataclass(init=False)
+class PrCfg:
+    """Directives that we allow in the commit body to control test behavior.
+
+    These are of the form `ci: foo`, at the start of a line.
+    """
+
+    # Skip regression checks (must be at the start of a line).
+    allow_regressions: bool = False
+    # Don't run extensive tests
+    skip_extensive: bool = False
+
+    # Allow running a large number of extensive tests. If not set, this script
+    # will error out if a threshold is exceeded in order to avoid accidentally
+    # spending huge amounts of CI time.
+    allow_many_extensive: bool = False
+
+    # Max number of extensive tests to run by default
+    MANY_EXTENSIVE_THRESHOLD: int = 20
+
+    # Run tests for `libm` that may otherwise be skipped due to no changed files.
+    always_test_libm: bool = False
+
+    # String values of directive names
+    DIR_ALLOW_REGRESSIONS: str = "allow-regressions"
+    DIR_SKIP_EXTENSIVE: str = "skip-extensive"
+    DIR_ALLOW_MANY_EXTENSIVE: str = "allow-many-extensive"
+    DIR_TEST_LIBM: str = "test-libm"
+
+    def __init__(self, body: str):
+        directives = re.finditer(r"^\s*ci:\s*(?P<dir_name>\S*)", body, re.MULTILINE)
+        for dir in directives:
+            name = dir.group("dir_name")
+            if name == self.DIR_ALLOW_REGRESSIONS:
+                self.allow_regressions = True
+            elif name == self.DIR_SKIP_EXTENSIVE:
+                self.skip_extensive = True
+            elif name == self.DIR_ALLOW_MANY_EXTENSIVE:
+                self.allow_many_extensive = True
+            elif name == self.DIR_TEST_LIBM:
+                self.always_test_libm = True
+            else:
+                eprint(f"Found unexpected directive `{name}`")
+                exit(1)
+
+        pprint.pp(self)
+
+
 @dataclass
 class PrInfo:
     """GitHub response for PR query"""
@@ -88,10 +129,21 @@ class PrInfo:
     commits: list[str]
     created_at: str
     number: int
+    cfg: PrCfg
 
     @classmethod
-    def load(cls, pr_number: int | str) -> Self:
-        """For a given PR number, query the body and commit list"""
+    def from_env(cls) -> Self | None:
+        """Create a PR object from the PR_NUMBER environment if set, `None` otherwise."""
+        pr_env = os.environ.get("PR_NUMBER")
+        if pr_env is not None and len(pr_env) > 0:
+            return cls.from_pr(pr_env)
+
+        return None
+
+    @classmethod
+    @cache  # Cache so we don't print info messages multiple times
+    def from_pr(cls, pr_number: int | str) -> Self:
+        """For a given PR number, query the body and commit list."""
         pr_info = sp.check_output(
             [
                 "gh",
@@ -104,13 +156,9 @@ class PrInfo:
             ],
             text=True,
         )
-        eprint("PR info:", json.dumps(pr_info, indent=4))
-        return cls(**json.loads(pr_info))
-
-    def contains_directive(self, directive: str) -> bool:
-        """Return true if the provided directive is on a line in the PR body"""
-        lines = self.body.splitlines()
-        return any(line.startswith(directive) for line in lines)
+        pr_json = json.loads(pr_info)
+        eprint("PR info:", json.dumps(pr_json, indent=4))
+        return cls(**json.loads(pr_info), cfg=PrCfg(pr_json["body"]))
 
 
 class FunctionDef(TypedDict):
@@ -207,26 +255,32 @@ class Context:
         """If this is a PR and no libm files were changed, allow skipping libm
         jobs."""
 
-        if self.is_pr():
-            return all(not re.match(TRIGGER_LIBM_PR_CI, str(f)) for f in self.changed)
+        # Always run on merge CI
+        if not self.is_pr():
+            return False
 
-        return False
+        pr = PrInfo.from_env()
+        assert pr is not None, "Is a PR but couldn't load PrInfo"
+
+        # Allow opting in to libm tests
+        if pr.cfg.always_test_libm:
+            return False
+
+        # By default, run if there are any changed files matching the pattern
+        return all(not re.match(TRIGGER_LIBM_CI_FILE_PAT, str(f)) for f in self.changed)
 
     def emit_workflow_output(self):
         """Create a JSON object a list items for each type's changed files, if any
         did change, and the routines that were affected by the change.
         """
 
-        pr_number = os.environ.get("PR_NUMBER")
         skip_tests = False
         error_on_many_tests = False
 
-        if pr_number is not None and len(pr_number) > 0:
-            pr = PrInfo.load(pr_number)
-            skip_tests = pr.contains_directive(SKIP_EXTENSIVE_DIRECTIVE)
-            error_on_many_tests = not pr.contains_directive(
-                ALLOW_MANY_EXTENSIVE_DIRECTIVE
-            )
+        pr = PrInfo.from_env()
+        if pr is not None:
+            skip_tests = pr.cfg.skip_extensive
+            error_on_many_tests = not pr.cfg.allow_many_extensive
 
             if skip_tests:
                 eprint("Skipping all extensive tests")
@@ -253,16 +307,14 @@ class Context:
         may_skip = str(self.may_skip_libm_ci()).lower()
         print(f"extensive_matrix={ext_matrix}")
         print(f"may_skip_libm_ci={may_skip}")
-        eprint(f"extensive_matrix={ext_matrix}")
-        eprint(f"may_skip_libm_ci={may_skip}")
         eprint(f"total extensive tests: {total_to_test}")
 
-        if error_on_many_tests and total_to_test > MANY_EXTENSIVE_THRESHOLD:
+        if error_on_many_tests and total_to_test > PrCfg.MANY_EXTENSIVE_THRESHOLD:
             eprint(
-                f"More than {MANY_EXTENSIVE_THRESHOLD} tests would be run; add"
-                f" `{ALLOW_MANY_EXTENSIVE_DIRECTIVE}` to the PR body if this is"
+                f"More than {PrCfg.MANY_EXTENSIVE_THRESHOLD} tests would be run; add"
+                f" `{PrCfg.DIR_ALLOW_MANY_EXTENSIVE}` to the PR body if this is"
                 " intentional. If this is refactoring that happens to touch a lot of"
-                f" files, `{SKIP_EXTENSIVE_DIRECTIVE}` can be used instead."
+                f" files, `{PrCfg.DIR_SKIP_EXTENSIVE}` can be used instead."
             )
             exit(1)
 
@@ -371,8 +423,8 @@ def handle_bench_regressions(args: list[str]):
             eprint(USAGE)
             exit(1)
 
-    pr = PrInfo.load(pr_number)
-    if pr.contains_directive(REGRESSION_DIRECTIVE):
+    pr = PrInfo.from_pr(pr_number)
+    if pr.cfg.allow_regressions:
         eprint("PR allows regressions")
         return