1 use crate::load_protos;
2 use crate::{Flag, FlagSource};
3 use crate::{FlagPermission, FlagValue, ValuePickedFrom};
4 use aconfigd_protos::{
5     ProtoFlagQueryReturnMessage, ProtoListStorageMessage, ProtoListStorageMessageMsg,
6     ProtoStorageRequestMessage, ProtoStorageRequestMessageMsg, ProtoStorageRequestMessages,
7     ProtoStorageReturnMessage, ProtoStorageReturnMessageMsg, ProtoStorageReturnMessages,
8 };
9 use anyhow::anyhow;
10 use anyhow::Result;
11 use protobuf::Message;
12 use protobuf::SpecialFields;
13 use std::collections::HashMap;
14 use std::io::{Read, Write};
15 use std::net::Shutdown;
16 use std::os::unix::net::UnixStream;
17 
/// Stateless [`FlagSource`] that obtains flag state by querying aconfigd over the
/// `/dev/socket/aconfigd_system` Unix socket.
pub struct AconfigStorageSource;
19 
load_flag_to_container() -> Result<HashMap<String, String>>20 fn load_flag_to_container() -> Result<HashMap<String, String>> {
21     Ok(load_protos::load()?.into_iter().map(|p| (p.qualified_name(), p.container)).collect())
22 }
23 
convert(msg: ProtoFlagQueryReturnMessage, containers: &HashMap<String, String>) -> Result<Flag>24 fn convert(msg: ProtoFlagQueryReturnMessage, containers: &HashMap<String, String>) -> Result<Flag> {
25     let (value, value_picked_from) = match (
26         &msg.boot_flag_value,
27         msg.default_flag_value,
28         msg.local_flag_value,
29         msg.has_local_override,
30     ) {
31         (_, _, Some(local), Some(has_local)) if has_local => {
32             (FlagValue::try_from(local.as_str())?, ValuePickedFrom::Local)
33         }
34         (Some(boot), Some(default), _, _) => {
35             let value = FlagValue::try_from(boot.as_str())?;
36             if *boot == default {
37                 (value, ValuePickedFrom::Default)
38             } else {
39                 (value, ValuePickedFrom::Server)
40             }
41         }
42         _ => return Err(anyhow!("missing override")),
43     };
44 
45     let staged_value = match (msg.boot_flag_value, msg.server_flag_value, msg.has_server_override) {
46         (Some(boot), Some(server), _) if boot == server => None,
47         (Some(boot), Some(server), Some(has_server)) if boot != server && has_server => {
48             Some(FlagValue::try_from(server.as_str())?)
49         }
50         _ => None,
51     };
52 
53     let permission = match msg.is_readwrite {
54         Some(is_readwrite) => {
55             if is_readwrite {
56                 FlagPermission::ReadWrite
57             } else {
58                 FlagPermission::ReadOnly
59             }
60         }
61         None => return Err(anyhow!("missing permission")),
62     };
63 
64     let name = msg.flag_name.ok_or(anyhow!("missing flag name"))?;
65     let package = msg.package_name.ok_or(anyhow!("missing package name"))?;
66     let qualified_name = format!("{package}.{name}");
67     Ok(Flag {
68         name,
69         package,
70         value,
71         permission,
72         value_picked_from,
73         staged_value,
74         container: containers
75             .get(&qualified_name)
76             .cloned()
77             .unwrap_or_else(|| "<no container>".to_string())
78             .to_string(),
79         // TODO: remove once DeviceConfig is not in the CLI.
80         namespace: "-".to_string(),
81     })
82 }
83 
read_from_socket() -> Result<Vec<ProtoFlagQueryReturnMessage>>84 fn read_from_socket() -> Result<Vec<ProtoFlagQueryReturnMessage>> {
85     let messages = ProtoStorageRequestMessages {
86         msgs: vec![ProtoStorageRequestMessage {
87             msg: Some(ProtoStorageRequestMessageMsg::ListStorageMessage(ProtoListStorageMessage {
88                 msg: Some(ProtoListStorageMessageMsg::All(true)),
89                 special_fields: SpecialFields::new(),
90             })),
91             special_fields: SpecialFields::new(),
92         }],
93         special_fields: SpecialFields::new(),
94     };
95 
96     let socket_name = "/dev/socket/aconfigd_system";
97     let mut socket = UnixStream::connect(socket_name)?;
98 
99     let message_buffer = messages.write_to_bytes()?;
100     let mut message_length_buffer: [u8; 4] = [0; 4];
101     let message_size = &message_buffer.len();
102     message_length_buffer[0] = (message_size >> 24) as u8;
103     message_length_buffer[1] = (message_size >> 16) as u8;
104     message_length_buffer[2] = (message_size >> 8) as u8;
105     message_length_buffer[3] = *message_size as u8;
106     socket.write_all(&message_length_buffer)?;
107     socket.write_all(&message_buffer)?;
108     socket.shutdown(Shutdown::Write)?;
109 
110     let mut response_length_buffer: [u8; 4] = [0; 4];
111     socket.read_exact(&mut response_length_buffer)?;
112     let response_length = u32::from_be_bytes(response_length_buffer) as usize;
113     let mut response_buffer = vec![0; response_length];
114     socket.read_exact(&mut response_buffer)?;
115 
116     let response: ProtoStorageReturnMessages =
117         protobuf::Message::parse_from_bytes(&response_buffer)?;
118 
119     match response.msgs.as_slice() {
120         [ProtoStorageReturnMessage {
121             msg: Some(ProtoStorageReturnMessageMsg::ListStorageMessage(list_storage_message)),
122             ..
123         }] => Ok(list_storage_message.flags.clone()),
124         _ => Err(anyhow!("unexpected response from aconfigd")),
125     }
126 }
127 
128 impl FlagSource for AconfigStorageSource {
list_flags() -> Result<Vec<Flag>>129     fn list_flags() -> Result<Vec<Flag>> {
130         let containers = load_flag_to_container()?;
131         read_from_socket()
132             .map(|query_messages| {
133                 query_messages
134                     .iter()
135                     .map(|message| convert(message.clone(), &containers))
136                     .collect::<Vec<_>>()
137             })?
138             .into_iter()
139             .collect()
140     }
141 
override_flag(_namespace: &str, _qualified_name: &str, _value: &str) -> Result<()>142     fn override_flag(_namespace: &str, _qualified_name: &str, _value: &str) -> Result<()> {
143         todo!()
144     }
145 }
146