about summary refs log tree commit diff
path: root/src/ci/github-actions
diff options
context:
space:
mode:
authorJakub Beránek <berykubik@gmail.com>2024-05-02 21:42:36 +0200
committerJakub Beránek <berykubik@gmail.com>2024-05-05 14:06:23 +0200
commit6778ecf960693b834aadca37239ce6cfa268dc55 (patch)
tree40510ce0f3241ac17f606b5c633c6f6f9b4bbfdc /src/ci/github-actions
parent02f7806ecd641d67c8f046b073323c7e176ee6d2 (diff)
downloadrust-6778ecf960693b834aadca37239ce6cfa268dc55.tar.gz
rust-6778ecf960693b834aadca37239ce6cfa268dc55.zip
Use sum type for `WorkflowRunType`
Diffstat (limited to 'src/ci/github-actions')
-rwxr-xr-xsrc/ci/github-actions/calculate-job-matrix.py40
1 files changed, 26 insertions, 14 deletions
diff --git a/src/ci/github-actions/calculate-job-matrix.py b/src/ci/github-actions/calculate-job-matrix.py
index 68565f489c9..1801904d1e7 100755
--- a/src/ci/github-actions/calculate-job-matrix.py
+++ b/src/ci/github-actions/calculate-job-matrix.py
@@ -8,10 +8,10 @@ It reads job definitions from `src/ci/github-actions/jobs.yml`
 and filters them based on the event that happened on CI.
 """
 import dataclasses
-import enum
 import json
 import logging
 import os
+import typing
 from pathlib import Path
 from typing import List, Dict, Any, Optional
 
@@ -44,10 +44,22 @@ def add_base_env(jobs: List[Job], environment: Dict[str, str]) -> List[Job]:
     return jobs
 
 
-class WorkflowRunType(enum.Enum):
-    PR = enum.auto()
-    Try = enum.auto()
-    Auto = enum.auto()
+@dataclasses.dataclass
+class PRRunType:
+    pass
+
+
+@dataclasses.dataclass
+class TryRunType:
+    custom_jobs: List[str]
+
+
+@dataclasses.dataclass
+class AutoRunType:
+    pass
+
+
+WorkflowRunType = typing.Union[PRRunType, TryRunType, AutoRunType]
 
 
 @dataclasses.dataclass
@@ -59,7 +71,7 @@ class GitHubCtx:
 
 def find_run_type(ctx: GitHubCtx) -> Optional[WorkflowRunType]:
     if ctx.event_name == "pull_request":
-        return WorkflowRunType.PR
+        return PRRunType()
     elif ctx.event_name == "push":
         old_bors_try_build = (
             ctx.ref in ("refs/heads/try", "refs/heads/try-perf") and
@@ -72,20 +84,20 @@ def find_run_type(ctx: GitHubCtx) -> Optional[WorkflowRunType]:
         try_build = old_bors_try_build or new_bors_try_build
 
         if try_build:
-            return WorkflowRunType.Try
+            return TryRunType()
 
         if ctx.ref == "refs/heads/auto" and ctx.repository == "rust-lang-ci/rust":
-            return WorkflowRunType.Auto
+            return AutoRunType()
 
     return None
 
 
 def calculate_jobs(run_type: WorkflowRunType, job_data: Dict[str, Any]) -> List[Job]:
-    if run_type == WorkflowRunType.PR:
+    if isinstance(run_type, PRRunType):
         return add_base_env(name_jobs(job_data["pr"], "PR"), job_data["envs"]["pr"])
-    elif run_type == WorkflowRunType.Try:
+    elif isinstance(run_type, TryRunType):
         return add_base_env(name_jobs(job_data["try"], "try"), job_data["envs"]["try"])
-    elif run_type == WorkflowRunType.Auto:
+    elif isinstance(run_type, AutoRunType):
         return add_base_env(name_jobs(job_data["auto"], "auto"), job_data["envs"]["auto"])
 
     return []
@@ -107,11 +119,11 @@ def get_github_ctx() -> GitHubCtx:
 
 
 def format_run_type(run_type: WorkflowRunType) -> str:
-    if run_type == WorkflowRunType.PR:
+    if isinstance(run_type, PRRunType):
         return "pr"
-    elif run_type == WorkflowRunType.Auto:
+    elif isinstance(run_type, AutoRunType):
         return "auto"
-    elif run_type == WorkflowRunType.Try:
+    elif isinstance(run_type, TryRunType):
         return "try"
     else:
         raise AssertionError()