Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1693,10 +1693,14 @@ enum PolicyCommands {
#[arg(long = "rev", default_value_t = 0)]
rev: u32,

/// Include the full policy payload.
#[arg(long)]
/// Include the effective policy payload, including provider-composed entries.
#[arg(long, conflicts_with = "base")]
full: bool,

/// Include the base policy payload without provider-composed entries.
#[arg(long)]
base: bool,

/// Output format.
#[arg(short = 'o', long = "output", value_enum, default_value_t = PolicyGetOutput::Table)]
output: PolicyGetOutput,
Expand Down Expand Up @@ -2378,14 +2382,16 @@ async fn main() -> Result<()> {
name,
rev,
full,
base,
output,
global,
} => {
let view = run::PolicyGetView::from_flags(base, full);
if global {
run::sandbox_policy_get_global(
&ctx.endpoint,
rev,
full,
view,
output.as_str(),
&tls,
)
Expand All @@ -2396,7 +2402,7 @@ async fn main() -> Result<()> {
&ctx.endpoint,
&name,
rev,
full,
view,
output.as_str(),
&tls,
)
Expand Down Expand Up @@ -4364,17 +4370,42 @@ mod tests {
Some(Commands::Policy {
command:
Some(PolicyCommands::Get {
name, full, output, ..
name,
full,
base,
output,
..
}),
}) => {
assert_eq!(name.as_deref(), Some("my-sandbox"));
assert!(full);
assert!(!base);
assert!(matches!(output, PolicyGetOutput::Json));
}
other => panic!("expected policy get command, got: {other:?}"),
}
}

#[test]
fn policy_get_base_output_parses() {
let cli = Cli::try_parse_from(["openshell", "policy", "get", "my-sandbox", "--base"])
.expect("policy get --base should parse");

match cli.command {
Some(Commands::Policy {
command:
Some(PolicyCommands::Get {
name, full, base, ..
}),
}) => {
assert_eq!(name.as_deref(), Some("my-sandbox"));
assert!(!full);
assert!(base);
}
other => panic!("expected policy get command, got: {other:?}"),
}
}

#[test]
fn policy_delete_global_parses() {
let cli = Cli::try_parse_from(["openshell", "policy", "delete", "--global", "--yes"])
Expand Down
79 changes: 60 additions & 19 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use openshell_providers::{
profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml,
};
use owo_colors::OwoColorize;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::io::{ErrorKind, IsTerminal, Read, Write};
use std::path::{Path, PathBuf};
Expand All @@ -80,6 +81,27 @@ pub use openshell_core::forward::{
find_forward_by_port, list_forwards, stop_forward, stop_forwards_for_sandbox,
};

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PolicyGetView {
Metadata,
Base,
Full,
}

impl PolicyGetView {
pub fn from_flags(base: bool, full: bool) -> Self {
match (base, full) {
(true, _) => Self::Base,
(false, true) => Self::Full,
(false, false) => Self::Metadata,
}
}

fn includes_policy(self) -> bool {
matches!(self, Self::Base | Self::Full)
}
}

#[derive(Debug, PartialEq, Eq)]
enum SandboxUploadPlan {
GitAware {
Expand Down Expand Up @@ -6867,7 +6889,7 @@ pub async fn sandbox_policy_get(
server: &str,
name: &str,
version: u32,
full: bool,
view: PolicyGetView,
output: &str,
tls: &TlsOptions,
) -> Result<()> {
Expand All @@ -6877,7 +6899,7 @@ pub async fn sandbox_policy_get(
server,
name,
version,
full,
view,
output,
tls,
(&mut stdout, &mut stderr),
Expand All @@ -6901,7 +6923,7 @@ pub async fn sandbox_policy_get_to_writer<W, E>(
server: &str,
name: &str,
version: u32,
full: bool,
view: PolicyGetView,
output: &str,
tls: &TlsOptions,
writers: (&mut W, &mut E),
Expand All @@ -6911,7 +6933,7 @@ where
E: Write + Send,
{
if version == 0 {
return sandbox_policy_get_effective_to_writer(server, name, full, output, tls, writers)
return sandbox_policy_get_effective_to_writer(server, name, view, output, tls, writers)
.await;
}

Expand All @@ -6938,7 +6960,7 @@ where
Some(inner.active_version),
&rev,
status,
full,
view,
)?;
writeln!(
stdout,
Expand Down Expand Up @@ -6966,10 +6988,11 @@ where
writeln!(stdout, "Error: {}", rev.load_error).into_diagnostic()?;
}

if full {
if view.includes_policy() {
if let Some(ref policy) = rev.policy {
writeln!(stdout, "---").into_diagnostic()?;
let yaml_str = openshell_policy::serialize_sandbox_policy(policy)
let policy = policy_for_view(policy, view);
let yaml_str = openshell_policy::serialize_sandbox_policy(policy.as_ref())
.wrap_err("failed to serialize policy to YAML")?;
write!(stdout, "{yaml_str}").into_diagnostic()?;
} else {
Expand All @@ -6987,7 +7010,7 @@ where
async fn sandbox_policy_get_effective_to_writer<W, E>(
server: &str,
name: &str,
full: bool,
view: PolicyGetView,
output: &str,
tls: &TlsOptions,
writers: (&mut W, &mut E),
Expand Down Expand Up @@ -7060,10 +7083,11 @@ where
serde_json::json!(config.global_policy_version),
);
}
if full {
if view.includes_policy() {
let policy = policy_for_view(policy, view);
obj.insert(
"policy".to_string(),
openshell_policy::sandbox_policy_to_json_value(policy)?,
openshell_policy::sandbox_policy_to_json_value(policy.as_ref())?,
);
}
writeln!(
Expand All @@ -7083,9 +7107,10 @@ where
writeln!(stdout, "Global: {}", config.global_policy_version)
.into_diagnostic()?;
}
if full {
if view.includes_policy() {
writeln!(stdout, "---").into_diagnostic()?;
let yaml_str = openshell_policy::serialize_sandbox_policy(policy)
let policy = policy_for_view(policy, view);
let yaml_str = openshell_policy::serialize_sandbox_policy(policy.as_ref())
.wrap_err("failed to serialize policy to YAML")?;
write!(stdout, "{yaml_str}").into_diagnostic()?;
}
Expand All @@ -7099,7 +7124,7 @@ where
pub async fn sandbox_policy_get_global(
server: &str,
version: u32,
full: bool,
view: PolicyGetView,
output: &str,
tls: &TlsOptions,
) -> Result<()> {
Expand All @@ -7119,7 +7144,7 @@ pub async fn sandbox_policy_get_global(
let status = PolicyStatus::try_from(rev.status).unwrap_or(PolicyStatus::Unspecified);
match output {
"json" => {
let obj = policy_revision_to_json("global", None, None, &rev, status, full)?;
let obj = policy_revision_to_json("global", None, None, &rev, status, view)?;
println!("{}", serde_json::to_string_pretty(&obj).into_diagnostic()?);
return Ok(());
}
Expand All @@ -7138,10 +7163,11 @@ pub async fn sandbox_policy_get_global(
println!("Loaded: {} ms", rev.loaded_at_ms);
}

if full {
if view.includes_policy() {
if let Some(ref policy) = rev.policy {
println!("---");
let yaml_str = openshell_policy::serialize_sandbox_policy(policy)
let policy = policy_for_view(policy, view);
let yaml_str = openshell_policy::serialize_sandbox_policy(policy.as_ref())
.wrap_err("failed to serialize policy to YAML")?;
print!("{yaml_str}");
} else {
Expand Down Expand Up @@ -7171,7 +7197,7 @@ fn policy_revision_to_json(
active_version: Option<u32>,
rev: &openshell_core::proto::SandboxPolicyRevision,
status: PolicyStatus,
full: bool,
view: PolicyGetView,
) -> Result<serde_json::Value> {
let mut obj = serde_json::Map::new();
obj.insert("scope".to_string(), serde_json::json!(scope));
Expand Down Expand Up @@ -7205,16 +7231,31 @@ fn policy_revision_to_json(
if !rev.load_error.is_empty() {
obj.insert("load_error".to_string(), serde_json::json!(rev.load_error));
}
if full {
if view.includes_policy() {
let policy = match rev.policy.as_ref() {
Some(policy) => openshell_policy::sandbox_policy_to_json_value(policy)?,
Some(policy) => {
let policy = policy_for_view(policy, view);
openshell_policy::sandbox_policy_to_json_value(policy.as_ref())?
}
None => serde_json::Value::Null,
};
obj.insert("policy".to_string(), policy);
}
Ok(serde_json::Value::Object(obj))
}

fn policy_for_view(policy: &SandboxPolicy, view: PolicyGetView) -> Cow<'_, SandboxPolicy> {
if view != PolicyGetView::Base {
return Cow::Borrowed(policy);
}

let mut base_policy = policy.clone();
base_policy
.network_policies
.retain(|name, _| !openshell_policy::is_provider_rule_name(name));
Cow::Owned(base_policy)
}

pub async fn sandbox_policy_list(
server: &str,
name: &str,
Expand Down
Loading
Loading