zbus/
handshake.rs

1use async_trait::async_trait;
2use futures_util::{future::poll_fn, StreamExt};
3#[cfg(unix)]
4use nix::unistd::Uid;
5use std::{
6    collections::VecDeque,
7    convert::{TryFrom, TryInto},
8    fmt::{self, Debug},
9    path::PathBuf,
10    str::FromStr,
11};
12use tracing::{instrument, trace};
13use zvariant::Str;
14
15use sha1::{Digest, Sha1};
16
17use xdg_home::home_dir;
18
19#[cfg(windows)]
20use crate::win32;
21use crate::{
22    file::FileLines,
23    guid::Guid,
24    raw::{Connection, Socket},
25    Error, Result,
26};
27
/// Authentication mechanisms
///
/// Each variant corresponds to one SASL mechanism name on the wire; the mapping is implemented
/// by the `Display` and `FromStr` implementations of this type (`EXTERNAL`, `DBUS_COOKIE_SHA1`
/// and `ANONYMOUS`).
///
/// See <https://dbus.freedesktop.org/doc/dbus-specification.html#auth-mechanisms>
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuthMechanism {
    /// This is the recommended authentication mechanism on platforms where credentials can be
    /// transferred out-of-band, in particular Unix platforms that can perform credentials-passing
    /// over the `unix:` transport.
    ///
    /// Wire name: `EXTERNAL`.
    External,

    /// This mechanism is designed to establish that a client has the ability to read a private
    /// file owned by the user being authenticated.
    ///
    /// Wire name: `DBUS_COOKIE_SHA1`.
    Cookie,

    /// Does not perform any authentication at all, and should not be accepted by message buses.
    /// However, it might sometimes be useful for non-message-bus uses of D-Bus.
    ///
    /// Wire name: `ANONYMOUS`.
    Anonymous,
}
46
/// The result of a finalized handshake
///
/// The result of a finalized [`ClientHandshake`] or [`ServerHandshake`]. It can be passed to
/// [`Connection::new_authenticated`] to initialize a connection.
///
/// [`ClientHandshake`]: struct.ClientHandshake.html
/// [`ServerHandshake`]: struct.ServerHandshake.html
/// [`Connection::new_authenticated`]: ../struct.Connection.html#method.new_authenticated
#[derive(Debug)]
pub struct Authenticated<S> {
    /// The raw connection built from the handshake's socket and its receive buffer
    pub(crate) conn: Connection<S>,
    /// The server Guid
    pub(crate) server_guid: Guid,
    /// Whether file descriptor passing has been accepted by both sides
    #[cfg(unix)]
    pub(crate) cap_unix_fd: bool,
}
64
65impl<S> Authenticated<S>
66where
67    S: Socket + Unpin,
68{
69    /// Create a client-side `Authenticated` for the given `socket`.
70    pub async fn client(socket: S, mechanisms: Option<VecDeque<AuthMechanism>>) -> Result<Self> {
71        ClientHandshake::new(socket, mechanisms).perform().await
72    }
73
74    /// Create a server-side `Authenticated` for the given `socket`.
75    ///
76    /// The function takes `client_uid` on Unix only. On Windows, it takes `client_sid` instead.
77    pub async fn server(
78        socket: S,
79        guid: Guid,
80        #[cfg(unix)] client_uid: Option<u32>,
81        #[cfg(windows)] client_sid: Option<String>,
82        auth_mechanisms: Option<VecDeque<AuthMechanism>>,
83        cookie_id: Option<usize>,
84        cookie_context: CookieContext<'_>,
85    ) -> Result<Self> {
86        ServerHandshake::new(
87            socket,
88            guid,
89            #[cfg(unix)]
90            client_uid,
91            #[cfg(windows)]
92            client_sid,
93            auth_mechanisms,
94            cookie_id,
95            cookie_context,
96        )?
97        .perform()
98        .await
99    }
100}
101
102/*
103 * Client-side handshake logic
104 */
105
/// The states of the client-side handshake state machine driven by `ClientHandshake::perform`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[allow(clippy::upper_case_acronyms)]
enum ClientHandshakeStep {
    /// Nothing sent yet: the leading zero byte and the first `AUTH` command are still to be sent.
    Init,
    /// The previous mechanism was rejected; an `AUTH` for the next mechanism is to be sent.
    MechanismInit,
    /// `AUTH` was sent and a `DATA` challenge from the server is expected.
    WaitingForData,
    /// Waiting for the server's `OK` (a `DATA` or `REJECTED` reply is handled as well).
    WaitingForOK,
    /// `Command::NegotiateUnixFD` was sent; waiting for the server to agree or report an error.
    WaitingForAgreeUnixFD,
    /// `BEGIN` was sent; the handshake result can be returned.
    Done,
}
116
// The plain-text SASL profile authentication protocol described here:
// <https://dbus.freedesktop.org/doc/dbus-specification.html#auth-protocol>
//
// These are all the known commands, which can be parsed from or serialized to text.
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
enum Command {
    // `AUTH`, optionally followed by a mechanism name and a hex-encoded initial response.
    Auth(Option<AuthMechanism>, Option<Vec<u8>>),
    // `CANCEL`.
    Cancel,
    // `BEGIN`: the client's last handshake command.
    Begin,
    // `DATA`, optionally followed by a hex-encoded payload.
    Data(Option<Vec<u8>>),
    // An error report carrying a human-readable message.
    Error(String),
    // Sent by the client after `OK` when its socket can pass Unix file descriptors.
    NegotiateUnixFD,
    // Sent by the server on failed authentication, listing the mechanisms it is configured with.
    Rejected(Vec<AuthMechanism>),
    // Sent by the server on successful authentication, carrying the server GUID.
    Ok(Guid),
    // The server's positive reply to `NegotiateUnixFD`.
    AgreeUnixFD,
}
134
/// A representation of an in-progress handshake, client-side
///
/// This struct is an async-compatible representation of the initial handshake that must be
/// performed before a D-Bus connection can be used. To use it, create it with
/// [`ClientHandshake::new`] and await [`Handshake::perform`], which drives the whole exchange
/// and, on success, yields an [`Authenticated`] that can be given to
/// [`Connection::new_authenticated`].
///
/// [`ClientHandshake::new`]: struct.ClientHandshake.html#method.new
/// [`Handshake::perform`]: trait.Handshake.html#tymethod.perform
/// [`Authenticated`]: struct.Authenticated.html
/// [`Connection::new_authenticated`]: ../struct.Connection.html#method.new_authenticated
#[derive(Debug)]
pub struct ClientHandshake<S> {
    /// State shared with the server-side handshake (socket, mechanisms, server GUID, ...).
    common: HandshakeCommon<S>,
    /// The current position in the client state machine.
    step: ClientHandshakeStep,
}
153
/// A handshake that can be driven to completion.
///
/// Implemented in this file by both [`ClientHandshake`] and [`ServerHandshake`].
#[async_trait]
pub trait Handshake<S> {
    /// Perform the handshake.
    ///
    /// On a successful handshake, you get an `Authenticated`. If you need to send a Bus Hello,
    /// this remains to be done.
    ///
    /// The handshake is consumed, whether it succeeds or fails.
    async fn perform(mut self) -> Result<Authenticated<S>>;
}
162
163impl<S: Socket> ClientHandshake<S> {
164    /// Start a handshake on this client socket
165    pub fn new(socket: S, mechanisms: Option<VecDeque<AuthMechanism>>) -> ClientHandshake<S> {
166        let mechanisms = mechanisms.unwrap_or_else(|| {
167            let mut mechanisms = VecDeque::new();
168            mechanisms.push_back(AuthMechanism::External);
169            mechanisms.push_back(AuthMechanism::Cookie);
170            mechanisms.push_back(AuthMechanism::Anonymous);
171            mechanisms
172        });
173
174        ClientHandshake {
175            common: HandshakeCommon::new(socket, mechanisms, None),
176            step: ClientHandshakeStep::Init,
177        }
178    }
179
180    fn mechanism_init(&mut self) -> Result<(ClientHandshakeStep, Command)> {
181        use ClientHandshakeStep::*;
182        let mech = self.common.mechanism()?;
183        match mech {
184            AuthMechanism::Anonymous => Ok((
185                WaitingForOK,
186                Command::Auth(Some(*mech), Some("zbus".into())),
187            )),
188            AuthMechanism::External => Ok((
189                WaitingForOK,
190                Command::Auth(Some(*mech), Some(sasl_auth_id()?.into_bytes())),
191            )),
192            AuthMechanism::Cookie => Ok((
193                WaitingForData,
194                Command::Auth(Some(*mech), Some(sasl_auth_id()?.into_bytes())),
195            )),
196        }
197    }
198
199    async fn mechanism_data(&mut self, data: Vec<u8>) -> Result<(ClientHandshakeStep, Command)> {
200        let mech = self.common.mechanism()?;
201        match mech {
202            AuthMechanism::Cookie => {
203                let context = std::str::from_utf8(&data)
204                    .map_err(|_| Error::Handshake("Cookie context was not valid UTF-8".into()))?;
205                let mut split = context.split_ascii_whitespace();
206                let context = split
207                    .next()
208                    .ok_or_else(|| Error::Handshake("Missing cookie context name".into()))?;
209                let context = Str::from(context).try_into()?;
210                let id = split
211                    .next()
212                    .ok_or_else(|| Error::Handshake("Missing cookie ID".into()))?;
213                let id = id
214                    .parse()
215                    .map_err(|e| Error::Handshake(format!("Invalid cookie ID `{id}`: {e}")))?;
216                let server_challenge = split
217                    .next()
218                    .ok_or_else(|| Error::Handshake("Missing cookie challenge".into()))?;
219
220                let cookie = Cookie::lookup(&context, id).await?.cookie;
221                let client_challenge = random_ascii(16);
222                let sec = format!("{server_challenge}:{client_challenge}:{cookie}");
223                let sha1 = hex::encode(Sha1::digest(sec));
224                let data = format!("{client_challenge} {sha1}");
225                Ok((
226                    ClientHandshakeStep::WaitingForOK,
227                    Command::Data(Some(data.into())),
228                ))
229            }
230            _ => Err(Error::Handshake("Unexpected mechanism DATA".into())),
231        }
232    }
233}
234
235fn random_ascii(len: usize) -> String {
236    use rand::{distributions::Alphanumeric, thread_rng, Rng};
237    use std::iter;
238
239    let mut rng = thread_rng();
240    iter::repeat(())
241        .map(|()| rng.sample(Alphanumeric))
242        .map(char::from)
243        .take(len)
244        .collect()
245}
246
247fn sasl_auth_id() -> Result<String> {
248    let id = {
249        #[cfg(unix)]
250        {
251            Uid::effective().to_string()
252        }
253
254        #[cfg(windows)]
255        {
256            win32::ProcessToken::open(None)?.sid()?
257        }
258    };
259
260    Ok(id)
261}
262
/// One entry of a `DBUS_COOKIE_SHA1` keyring file.
#[derive(Debug)]
struct Cookie {
    /// The cookie ID (first field of a keyring line).
    id: usize,
    /// The cookie secret (third field of a keyring line).
    cookie: String,
}
268
269impl Cookie {
270    fn keyring_path() -> Result<PathBuf> {
271        let mut path = home_dir()
272            .ok_or_else(|| Error::Handshake("Failed to determine home directory".into()))?;
273        path.push(".dbus-keyrings");
274        Ok(path)
275    }
276
277    async fn read_keyring(context: &CookieContext<'_>) -> Result<Vec<Cookie>> {
278        let mut path = Cookie::keyring_path()?;
279        #[cfg(unix)]
280        {
281            use std::os::unix::fs::PermissionsExt;
282
283            let perms = crate::file::metadata(&path).await?.permissions().mode();
284            if perms & 0o066 != 0 {
285                return Err(Error::Handshake(
286                    "DBus keyring has invalid permissions".into(),
287                ));
288            }
289        }
290        #[cfg(not(unix))]
291        {
292            // FIXME: add code to check directory permissions
293        }
294        path.push(&*context.0);
295        trace!("Reading keyring {:?}", path);
296        let mut lines = FileLines::open(&path).await?.enumerate();
297        let mut cookies = vec![];
298        while let Some((n, line)) = lines.next().await {
299            let line = line?;
300            let mut split = line.split_whitespace();
301            let id = split
302                .next()
303                .ok_or_else(|| {
304                    Error::Handshake(format!(
305                        "DBus cookie `{}` missing ID at line {n}",
306                        path.display(),
307                    ))
308                })?
309                .parse()
310                .map_err(|e| {
311                    Error::Handshake(format!(
312                        "Failed to parse cookie ID in file `{}` at line {n}: {e}",
313                        path.display(),
314                    ))
315                })?;
316            let _ = split.next().ok_or_else(|| {
317                Error::Handshake(format!(
318                    "DBus cookie `{}` missing creation time at line {n}",
319                    path.display(),
320                ))
321            })?;
322            let cookie = split
323                .next()
324                .ok_or_else(|| {
325                    Error::Handshake(format!(
326                        "DBus cookie `{}` missing cookie data at line {}",
327                        path.to_str().unwrap(),
328                        n
329                    ))
330                })?
331                .to_string();
332            cookies.push(Cookie { id, cookie })
333        }
334        trace!("Loaded keyring {:?}", cookies);
335        Ok(cookies)
336    }
337
338    async fn lookup(context: &CookieContext<'_>, id: usize) -> Result<Cookie> {
339        let keyring = Self::read_keyring(context).await?;
340        keyring
341            .into_iter()
342            .find(|c| c.id == id)
343            .ok_or_else(|| Error::Handshake(format!("DBus cookie ID {id} not found")))
344    }
345
346    async fn first(context: &CookieContext<'_>) -> Result<Cookie> {
347        let keyring = Self::read_keyring(context).await?;
348        keyring
349            .into_iter()
350            .next()
351            .ok_or_else(|| Error::Handshake("No cookies available".into()))
352    }
353}
354
/// The name of a `DBUS_COOKIE_SHA1` cookie context.
///
/// The name selects the keyring file inside the keyring directory. Construct one through the
/// `TryFrom<Str>` implementation, which validates the name, or use `Default` for
/// `org_freedesktop_general`.
#[derive(Debug)]
pub struct CookieContext<'c>(Str<'c>);
357
358impl<'c> TryFrom<Str<'c>> for CookieContext<'c> {
359    type Error = Error;
360
361    fn try_from(value: Str<'c>) -> Result<Self> {
362        if value.is_empty() {
363            return Err(Error::Handshake("Empty cookie context".into()));
364        } else if !value.is_ascii() || value.contains(['/', '\\', ' ', '\n', '\r', '\t', '.']) {
365            return Err(Error::Handshake(
366                "Invalid characters in cookie context".into(),
367            ));
368        }
369
370        Ok(Self(value))
371    }
372}
373
374impl Default for CookieContext<'_> {
375    fn default() -> Self {
376        Self(Str::from_static("org_freedesktop_general"))
377    }
378}
379
#[async_trait]
impl<S: Socket> Handshake<S> for ClientHandshake<S> {
    /// Drive the client side of the handshake to completion.
    ///
    /// Each loop iteration handles the current step, producing the command to send and the step
    /// to enter afterwards. The command is written and the step updated at the bottom of the
    /// loop. A `REJECTED` reply drops the current mechanism and retries with the next one
    /// without sending anything in that iteration.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, on an unexpected server reply, or when `mechanism_init` /
    /// `mechanism_data` fail.
    #[instrument(skip(self))]
    async fn perform(mut self) -> Result<Authenticated<S>> {
        use ClientHandshakeStep::*;
        loop {
            let (next_step, cmd) = match self.step {
                Init => {
                    trace!("Initializing");
                    // Compute the first AUTH command before anything is written to the socket.
                    #[allow(clippy::let_and_return)]
                    let ret = self.mechanism_init()?;
                    // The dbus daemon on some platforms requires sending the zero byte as a
                    // separate message with SCM_CREDS.
                    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
                    let written = self
                        .common
                        .socket
                        .send_zero_byte()
                        .map_err(|e| {
                            Error::Handshake(format!(
                                "Could not send zero byte with credentials: {}",
                                e
                            ))
                        })
                        .and_then(|n| match n {
                            None => Err(Error::Handshake(
                                "Could not send zero byte with credentials".to_string(),
                            )),
                            Some(n) => Ok(n),
                        })?;

                    // leading 0 is sent separately already for `freebsd` and `dragonfly` above.
                    #[cfg(not(any(target_os = "freebsd", target_os = "dragonfly")))]
                    let written = poll_fn(|cx| {
                        self.common.socket.poll_sendmsg(
                            cx,
                            &[b'\0'],
                            #[cfg(unix)]
                            &[],
                        )
                    })
                    .await?;

                    // Exactly one byte must have gone out, whichever path was taken above.
                    if written != 1 {
                        return Err(Error::Handshake(
                            "Could not send zero byte with credentials".to_string(),
                        ));
                    }

                    ret
                }
                MechanismInit => {
                    trace!("Initializing auth mechanisms");
                    self.mechanism_init()?
                }
                WaitingForData | WaitingForOK => {
                    trace!("Waiting for DATA or OK from server");
                    let reply = self.common.read_command().await?;
                    match (self.step, reply) {
                        // A challenge is accepted in either step; the mechanism decides whether
                        // it is legitimate.
                        (_, Command::Data(data)) => {
                            trace!("Received DATA from server");
                            let data = data.ok_or_else(|| {
                                Error::Handshake("Received DATA with no data from server".into())
                            })?;
                            self.mechanism_data(data).await?
                        }
                        (_, Command::Rejected(_)) => {
                            trace!("Received REJECT from server. Will try next auth mechanism..");
                            // Drop the rejected mechanism and restart with the next one; nothing
                            // is sent in this iteration.
                            self.common.mechanisms.pop_front();
                            self.step = MechanismInit;
                            continue;
                        }
                        // `OK` is only valid once no further challenge is expected.
                        (WaitingForOK, Command::Ok(guid)) => {
                            trace!("Received OK from server");
                            self.common.server_guid = Some(guid);
                            if self.common.socket.can_pass_unix_fd() {
                                (WaitingForAgreeUnixFD, Command::NegotiateUnixFD)
                            } else {
                                (Done, Command::Begin)
                            }
                        }
                        (_, reply) => {
                            return Err(Error::Handshake(format!(
                                "Unexpected server AUTH OK reply: {reply}"
                            )));
                        }
                    }
                }
                WaitingForAgreeUnixFD => {
                    trace!("Waiting for Unix FD passing agreement from server");
                    let reply = self.common.read_command().await?;
                    match reply {
                        Command::AgreeUnixFD => {
                            trace!("Unix FD passing agreed by server");
                            self.common.cap_unix_fd = true
                        }
                        // An error reply only disables FD passing; the handshake continues.
                        Command::Error(_) => {
                            trace!("Unix FD passing rejected by server");
                            self.common.cap_unix_fd = false
                        }
                        _ => {
                            return Err(Error::Handshake(format!(
                                "Unexpected server UNIX_FD reply: {reply}"
                            )));
                        }
                    }
                    (Done, Command::Begin)
                }
                Done => {
                    trace!("Handshake done");
                    return Ok(Authenticated {
                        conn: Connection::new(self.common.socket, self.common.recv_buffer),
                        // `Done` is only entered after an `OK` reply stored the GUID.
                        server_guid: self.common.server_guid.unwrap(),
                        #[cfg(unix)]
                        cap_unix_fd: self.common.cap_unix_fd,
                    });
                }
            };
            self.common.write_command(cmd).await?;
            self.step = next_step;
        }
    }
}
503
504/*
505 * Server-side handshake logic
506 */
507
/// The states of the server-side handshake state machine driven by `ServerHandshake::perform`.
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
enum ServerHandshakeStep {
    /// Waiting for the single NUL byte the client sends first.
    WaitingForNull,
    /// Waiting for an `AUTH` command from the client.
    WaitingForAuth,
    /// An `AUTH` without initial response was received for the given mechanism and an empty
    /// `DATA` was sent back; waiting for the client's `DATA`.
    WaitingForData(AuthMechanism),
    /// Authentication succeeded; waiting for `BEGIN` (FD passing may be negotiated meanwhile).
    WaitingForBegin,
    /// `BEGIN` was received; the handshake result can be returned.
    Done,
}
517
/// A representation of an in-progress handshake, server-side
///
/// This would typically be used to implement a D-Bus broker, or in the context of a P2P connection.
///
/// This struct is an async-compatible representation of the initial handshake that must be
/// performed before a D-Bus connection can be used. To use it, create it with
/// [`ServerHandshake::new`] and await [`Handshake::perform`], which drives the whole exchange
/// and, on success, yields an [`Authenticated`] that can be given to
/// [`Connection::new_authenticated`].
///
/// [`ServerHandshake::new`]: struct.ServerHandshake.html#method.new
/// [`Handshake::perform`]: trait.Handshake.html#tymethod.perform
/// [`Authenticated`]: struct.Authenticated.html
/// [`Connection::new_authenticated`]: ../struct.Connection.html#method.new_authenticated
#[derive(Debug)]
pub struct ServerHandshake<'s, S> {
    /// State shared with the client-side handshake (socket, mechanisms, server GUID, ...).
    common: HandshakeCommon<S>,
    /// The current position in the server state machine.
    step: ServerHandshakeStep,
    /// The UID that `EXTERNAL` authentication must match; `None` rejects every ID.
    #[cfg(unix)]
    client_uid: Option<u32>,
    /// The SID that `EXTERNAL` authentication must match; `None` rejects every ID.
    #[cfg(windows)]
    client_sid: Option<String>,
    /// The cookie to challenge with; when `None`, the first cookie of the keyring is used.
    cookie_id: Option<usize>,
    /// The cookie context whose keyring is consulted for `DBUS_COOKIE_SHA1`.
    cookie_context: CookieContext<'s>,
}
544
545impl<'s, S: Socket> ServerHandshake<'s, S> {
546    pub fn new(
547        socket: S,
548        guid: Guid,
549        #[cfg(unix)] client_uid: Option<u32>,
550        #[cfg(windows)] client_sid: Option<String>,
551        mechanisms: Option<VecDeque<AuthMechanism>>,
552        cookie_id: Option<usize>,
553        cookie_context: CookieContext<'s>,
554    ) -> Result<ServerHandshake<'s, S>> {
555        let mechanisms = match mechanisms {
556            Some(mechanisms) => mechanisms,
557            None => {
558                let mut mechanisms = VecDeque::new();
559                mechanisms.push_back(AuthMechanism::External);
560
561                mechanisms
562            }
563        };
564
565        Ok(ServerHandshake {
566            common: HandshakeCommon::new(socket, mechanisms, Some(guid)),
567            step: ServerHandshakeStep::WaitingForNull,
568            #[cfg(unix)]
569            client_uid,
570            #[cfg(windows)]
571            client_sid,
572            cookie_id,
573            cookie_context,
574        })
575    }
576
577    async fn auth_ok(&mut self) -> Result<()> {
578        let cmd = Command::Ok(self.guid().clone());
579        trace!("Sending authentication OK");
580        self.common.write_command(cmd).await?;
581        self.step = ServerHandshakeStep::WaitingForBegin;
582
583        Ok(())
584    }
585
586    async fn check_external_auth(&mut self, sasl_id: &[u8]) -> Result<()> {
587        let auth_ok = {
588            let id = std::str::from_utf8(sasl_id)
589                .map_err(|e| Error::Handshake(format!("Invalid ID: {e}")))?;
590            #[cfg(unix)]
591            {
592                let uid = id
593                    .parse::<u32>()
594                    .map_err(|e| Error::Handshake(format!("Invalid UID: {e}")))?;
595                self.client_uid.map(|u| u == uid).unwrap_or(false)
596            }
597            #[cfg(windows)]
598            {
599                self.client_sid.as_ref().map(|u| u == id).unwrap_or(false)
600            }
601        };
602
603        if auth_ok {
604            self.auth_ok().await
605        } else {
606            self.rejected_error().await
607        }
608    }
609
610    async fn check_cookie_auth(&mut self, sasl_id: &[u8]) -> Result<()> {
611        let cookie = match self.cookie_id {
612            Some(cookie_id) => Cookie::lookup(&self.cookie_context, cookie_id).await?,
613            None => Cookie::first(&self.cookie_context).await?,
614        };
615        let id = std::str::from_utf8(sasl_id)
616            .map_err(|e| Error::Handshake(format!("Invalid ID: {e}")))?;
617        if sasl_auth_id()? != id {
618            // While the spec will make you believe that DBUS_COOKIE_SHA1 can be used to
619            // authenticate any user, it is not even possible (or correct) for the server to manage
620            // contents in random users' home directories.
621            //
622            // The dbus reference implementation also has the same limitation/behavior.
623            self.rejected_error().await?;
624            return Ok(());
625        }
626        let server_challenge = random_ascii(16);
627        let data = format!("{} {} {server_challenge}", self.cookie_context.0, cookie.id);
628        let cmd = Command::Data(Some(data.into_bytes()));
629        trace!("Sending DBUS_COOKIE_SHA1 authentication challenge");
630        self.common.write_command(cmd).await?;
631
632        let auth_data = match self.common.read_command().await? {
633            Command::Data(data) => data,
634            _ => None,
635        };
636        let auth_data = auth_data.ok_or_else(|| {
637            Error::Handshake("Expected DBUS_COOKIE_SHA1 authentication challenge response".into())
638        })?;
639        let client_auth = std::str::from_utf8(&auth_data)
640            .map_err(|e| Error::Handshake(format!("Invalid COOKIE authentication data: {e}")))?;
641        let mut split = client_auth.split_ascii_whitespace();
642        let client_challenge = split
643            .next()
644            .ok_or_else(|| Error::Handshake("Missing cookie challenge".into()))?;
645        let client_sha1 = split
646            .next()
647            .ok_or_else(|| Error::Handshake("Missing client cookie data".into()))?;
648        let sec = format!("{server_challenge}:{client_challenge}:{}", cookie.cookie);
649        let sha1 = hex::encode(Sha1::digest(sec));
650
651        if sha1 == client_sha1 {
652            self.auth_ok().await
653        } else {
654            self.rejected_error().await
655        }
656    }
657
658    async fn unsupported_command_error(&mut self) -> Result<()> {
659        let cmd = Command::Error("Unsupported command".to_string());
660        trace!("Sending authentication error");
661        self.common.write_command(cmd).await?;
662        self.step = ServerHandshakeStep::WaitingForAuth;
663
664        Ok(())
665    }
666
667    async fn rejected_error(&mut self) -> Result<()> {
668        let mechanisms = self.common.mechanisms.iter().cloned().collect();
669        let cmd = Command::Rejected(mechanisms);
670        trace!("Sending authentication error");
671        self.common.write_command(cmd).await?;
672        self.step = ServerHandshakeStep::WaitingForAuth;
673
674        Ok(())
675    }
676
677    fn guid(&self) -> &Guid {
678        // SAFETY: We know that the server GUID is set because we set it in the constructor.
679        self.common
680            .server_guid
681            .as_ref()
682            .expect("Server GUID not set")
683    }
684}
685
686#[async_trait]
687impl<S: Socket> Handshake<S> for ServerHandshake<'_, S> {
688    #[instrument(skip(self))]
689    async fn perform(mut self) -> Result<Authenticated<S>> {
690        loop {
691            match self.step {
692                ServerHandshakeStep::WaitingForNull => {
693                    trace!("Waiting for NULL");
694                    let mut buffer = [0; 1];
695                    let read =
696                        poll_fn(|cx| self.common.socket.poll_recvmsg(cx, &mut buffer)).await?;
697                    #[cfg(unix)]
698                    let read = read.0;
699                    // recvmsg cannot return anything else than Ok(1) or Err
700                    debug_assert!(read == 1);
701                    if buffer[0] != 0 {
702                        return Err(Error::Handshake(
703                            "First client byte is not NUL!".to_string(),
704                        ));
705                    }
706                    trace!("Received NULL from client");
707                    self.step = ServerHandshakeStep::WaitingForAuth;
708                }
709                ServerHandshakeStep::WaitingForAuth => {
710                    trace!("Waiting for authentication");
711                    let reply = self.common.read_command().await?;
712                    match reply {
713                        Command::Auth(mech, resp) => {
714                            let mech = mech.filter(|m| self.common.mechanisms.contains(m));
715
716                            match (mech, &resp) {
717                                (Some(mech), None) => {
718                                    trace!("Sending data request");
719                                    self.common.write_command(Command::Data(None)).await?;
720                                    self.step = ServerHandshakeStep::WaitingForData(mech);
721                                }
722                                (Some(AuthMechanism::Anonymous), Some(_)) => {
723                                    self.auth_ok().await?;
724                                }
725                                (Some(AuthMechanism::External), Some(sasl_id)) => {
726                                    self.check_external_auth(sasl_id).await?;
727                                }
728                                (Some(AuthMechanism::Cookie), Some(sasl_id)) => {
729                                    self.check_cookie_auth(sasl_id).await?;
730                                }
731                                _ => self.rejected_error().await?,
732                            }
733                        }
734                        Command::Error(_) => self.rejected_error().await?,
735                        Command::Begin => {
736                            return Err(Error::Handshake(
737                                "Received BEGIN while not authenticated".to_string(),
738                            ));
739                        }
740                        _ => self.unsupported_command_error().await?,
741                    }
742                }
743                ServerHandshakeStep::WaitingForData(mech) => {
744                    trace!("Waiting for authentication");
745                    let reply = self.common.read_command().await?;
746                    match (mech, reply) {
747                        (AuthMechanism::External, Command::Data(None)) => self.auth_ok().await?,
748                        (AuthMechanism::External, Command::Data(Some(data))) => {
749                            self.check_external_auth(&data).await?;
750                        }
751                        (AuthMechanism::Anonymous, Command::Data(_)) => self.auth_ok().await?,
752                        (_, Command::Data(_)) => self.rejected_error().await?,
753                        (_, _) => self.unsupported_command_error().await?,
754                    }
755                }
756                ServerHandshakeStep::WaitingForBegin => {
757                    trace!("Waiting for Begin command from the client");
758                    let reply = self.common.read_command().await?;
759                    match reply {
760                        Command::Begin => {
761                            trace!("Received Begin command from the client");
762                            self.step = ServerHandshakeStep::Done;
763                        }
764                        Command::Cancel | Command::Error(_) => {
765                            trace!("Received CANCEL or ERROR command from the client");
766                            self.rejected_error().await?;
767                        }
768                        #[cfg(unix)]
769                        Command::NegotiateUnixFD => {
770                            trace!("Received NEGOTIATE_UNIX_FD command from the client");
771                            self.common.cap_unix_fd = true;
772                            trace!("Sending AGREE_UNIX_FD to the client");
773                            self.common.write_command(Command::AgreeUnixFD).await?;
774                            self.step = ServerHandshakeStep::WaitingForBegin;
775                        }
776                        _ => self.unsupported_command_error().await?,
777                    }
778                }
779                ServerHandshakeStep::Done => {
780                    trace!("Handshake done");
781                    return Ok(Authenticated {
782                        conn: Connection::new(self.common.socket, self.common.recv_buffer),
783                        // SAFETY: We know that the server GUID is set because we set it in the
784                        // constructor.
785                        server_guid: self.common.server_guid.expect("Server GUID not set"),
786                        #[cfg(unix)]
787                        cap_unix_fd: self.common.cap_unix_fd,
788                    });
789                }
790            }
791        }
792    }
793}
794
795impl fmt::Display for AuthMechanism {
796    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
797        let mech = match self {
798            AuthMechanism::External => "EXTERNAL",
799            AuthMechanism::Cookie => "DBUS_COOKIE_SHA1",
800            AuthMechanism::Anonymous => "ANONYMOUS",
801        };
802        write!(f, "{mech}")
803    }
804}
805
806impl FromStr for AuthMechanism {
807    type Err = Error;
808
809    fn from_str(s: &str) -> Result<Self> {
810        match s {
811            "EXTERNAL" => Ok(AuthMechanism::External),
812            "DBUS_COOKIE_SHA1" => Ok(AuthMechanism::Cookie),
813            "ANONYMOUS" => Ok(AuthMechanism::Anonymous),
814            _ => Err(Error::Handshake(format!("Unknown mechanism: {s}"))),
815        }
816    }
817}
818
819impl From<Command> for Vec<u8> {
820    fn from(c: Command) -> Self {
821        c.to_string().into()
822    }
823}
824
825impl fmt::Display for Command {
826    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
827        match self {
828            Command::Auth(mech, resp) => match (mech, resp) {
829                (Some(mech), Some(resp)) => write!(f, "AUTH {mech} {}", hex::encode(resp)),
830                (Some(mech), None) => write!(f, "AUTH {mech}"),
831                _ => write!(f, "AUTH"),
832            },
833            Command::Cancel => write!(f, "CANCEL"),
834            Command::Begin => write!(f, "BEGIN"),
835            Command::Data(data) => match data {
836                None => write!(f, "DATA"),
837                Some(data) => write!(f, "DATA {}", hex::encode(data)),
838            },
839            Command::Error(expl) => write!(f, "ERROR {expl}"),
840            Command::NegotiateUnixFD => write!(f, "NEGOTIATE_UNIX_FD"),
841            Command::Rejected(mechs) => {
842                write!(
843                    f,
844                    "REJECTED {}",
845                    mechs
846                        .iter()
847                        .map(|m| m.to_string())
848                        .collect::<Vec<_>>()
849                        .join(" ")
850                )
851            }
852            Command::Ok(guid) => write!(f, "OK {guid}"),
853            Command::AgreeUnixFD => write!(f, "AGREE_UNIX_FD"),
854        }?;
855        write!(f, "\r\n")
856    }
857}
858
859impl From<hex::FromHexError> for Error {
860    fn from(e: hex::FromHexError) -> Self {
861        Error::Handshake(format!("Invalid hexcode: {e}"))
862    }
863}
864
865impl FromStr for Command {
866    type Err = Error;
867
868    fn from_str(s: &str) -> Result<Self> {
869        let mut words = s.split_ascii_whitespace();
870        let cmd = match words.next() {
871            Some("AUTH") => {
872                let mech = if let Some(m) = words.next() {
873                    Some(m.parse()?)
874                } else {
875                    None
876                };
877                let resp = match words.next() {
878                    Some(resp) => Some(hex::decode(resp)?),
879                    None => None,
880                };
881                Command::Auth(mech, resp)
882            }
883            Some("CANCEL") => Command::Cancel,
884            Some("BEGIN") => Command::Begin,
885            Some("DATA") => {
886                let data = match words.next() {
887                    Some(data) => Some(hex::decode(data)?),
888                    None => None,
889                };
890
891                Command::Data(data)
892            }
893            Some("ERROR") => Command::Error(s.into()),
894            Some("NEGOTIATE_UNIX_FD") => Command::NegotiateUnixFD,
895            Some("REJECTED") => {
896                let mechs = words.map(|m| m.parse()).collect::<Result<_>>()?;
897                Command::Rejected(mechs)
898            }
899            Some("OK") => {
900                let guid = words
901                    .next()
902                    .ok_or_else(|| Error::Handshake("Missing OK server GUID!".into()))?;
903                Command::Ok(guid.parse()?)
904            }
905            Some("AGREE_UNIX_FD") => Command::AgreeUnixFD,
906            _ => return Err(Error::Handshake(format!("Unknown command: {s}"))),
907        };
908        Ok(cmd)
909    }
910}
911
/// Common code for the client and server side of the handshake.
///
/// Holds the socket together with the state that both sides need while exchanging handshake
/// commands.
#[derive(Debug)]
pub struct HandshakeCommon<S> {
    /// The socket that handshake commands are sent to and received from.
    socket: S,
    /// Bytes received from the socket that have not been consumed as a command yet. The server
    /// side passes it on to `Connection::new` once the handshake is done.
    recv_buffer: Vec<u8>,
    /// The server GUID, if known. The server side always sets it in its constructor;
    /// presumably the client side supplies it only when one is expected or learned -- TODO
    /// confirm against the client code.
    server_guid: Option<Guid>,
    /// Whether file descriptor passing has been negotiated. Starts out `false`; the server
    /// side sets it when it answers `NEGOTIATE_UNIX_FD` with `AGREE_UNIX_FD`.
    cap_unix_fd: bool,
    // the current AUTH mechanism is front, ordered by priority
    mechanisms: VecDeque<AuthMechanism>,
}
922
923impl<S: Socket> HandshakeCommon<S> {
924    /// Start a handshake on this client socket
925    pub fn new(socket: S, mechanisms: VecDeque<AuthMechanism>, server_guid: Option<Guid>) -> Self {
926        Self {
927            socket,
928            recv_buffer: Vec::new(),
929            server_guid,
930            cap_unix_fd: false,
931            mechanisms,
932        }
933    }
934
935    #[instrument(skip(self))]
936    async fn write_command(&mut self, command: Command) -> Result<()> {
937        let mut send_buffer = Vec::<u8>::from(command);
938        while !send_buffer.is_empty() {
939            let written = poll_fn(|cx| {
940                self.socket.poll_sendmsg(
941                    cx,
942                    &send_buffer,
943                    #[cfg(unix)]
944                    &[],
945                )
946            })
947            .await?;
948            send_buffer.drain(..written);
949        }
950        Ok(())
951    }
952
953    #[instrument(skip(self))]
954    async fn read_command(&mut self) -> Result<Command> {
955        let mut cmd_end = 0;
956        loop {
957            if let Some(i) = self.recv_buffer[cmd_end..].iter().position(|b| *b == b'\n') {
958                if cmd_end + i == 0 || self.recv_buffer.get(cmd_end + i - 1) != Some(&b'\r') {
959                    return Err(Error::Handshake("Invalid line ending in handshake".into()));
960                }
961                cmd_end += i + 1;
962
963                break;
964            } else {
965                cmd_end = self.recv_buffer.len();
966            }
967
968            let mut buf = [0; 64];
969            let res = poll_fn(|cx| self.socket.poll_recvmsg(cx, &mut buf)).await?;
970            let read = {
971                #[cfg(unix)]
972                {
973                    let (read, fds) = res;
974                    if !fds.is_empty() {
975                        return Err(Error::Handshake("Unexpected FDs during handshake".into()));
976                    }
977                    read
978                }
979                #[cfg(not(unix))]
980                {
981                    res
982                }
983            };
984            if read == 0 {
985                return Err(Error::Handshake("Unexpected EOF during handshake".into()));
986            }
987            self.recv_buffer.extend(&buf[..read]);
988        }
989
990        let line_bytes = self.recv_buffer.drain(..cmd_end);
991        let line = std::str::from_utf8(line_bytes.as_slice())
992            .map_err(|e| Error::Handshake(e.to_string()))?;
993
994        line.parse()
995    }
996
997    fn mechanism(&self) -> Result<&AuthMechanism> {
998        self.mechanisms
999            .front()
1000            .ok_or_else(|| Error::Handshake("Exhausted available AUTH mechanisms".into()))
1001    }
1002}
1003
#[cfg(unix)]
#[cfg(test)]
mod tests {
    #[cfg(not(feature = "tokio"))]
    use async_std::io::{Write as AsyncWrite, WriteExt};
    use futures_util::future::join;
    use ntest::timeout;
    #[cfg(not(feature = "tokio"))]
    use std::os::unix::net::UnixStream;
    use test_log::test;
    #[cfg(feature = "tokio")]
    use tokio::{
        io::{AsyncWrite, AsyncWriteExt},
        net::UnixStream,
    };

    use super::*;

    use crate::Guid;

    /// Returns the two ends of a connected Unix stream pair, both usable as an async writer
    /// and as a handshake socket.
    fn create_async_socket_pair() -> (impl AsyncWrite + Socket, impl AsyncWrite + Socket) {
        // Tokio needs us to call the sync function from async context. :shrug:
        let (left, right) = crate::utils::block_on(async { UnixStream::pair().unwrap() });

        // Without tokio the pair consists of std streams, which are switched to non-blocking
        // mode and wrapped for async use.
        #[cfg(not(feature = "tokio"))]
        let (left, right) = {
            let wrap = |stream: UnixStream| {
                stream.set_nonblocking(true).unwrap();
                async_io::Async::new(stream).unwrap()
            };

            (wrap(left), wrap(right))
        };

        (left, right)
    }

    /// Runs a full client handshake against a full server handshake and checks that both ends
    /// agree on the server GUID and on file descriptor passing.
    #[test]
    fn handshake() {
        let (client_end, server_end) = create_async_socket_pair();

        let client = ClientHandshake::new(client_end, None);
        let server = ServerHandshake::new(
            server_end,
            Guid::generate(),
            Some(Uid::effective().into()),
            None,
            None,
            CookieContext::default(),
        )
        .unwrap();

        // Drive both sides concurrently until each one is done.
        let client_side = async move { client.perform().await.unwrap() };
        let server_side = async move { server.perform().await.unwrap() };
        let (client_auth, server_auth) = crate::utils::block_on(join(client_side, server_side));

        assert_eq!(client_auth.server_guid, server_auth.server_guid);
        assert_eq!(client_auth.cap_unix_fd, server_auth.cap_unix_fd);
    }

    /// The client writes all of its commands up front, before the server has answered any of
    /// them; the server must still complete and report FD passing as negotiated.
    #[test]
    #[timeout(15000)]
    fn pipelined_handshake() {
        let (mut client_end, server_end) = create_async_socket_pair();
        let server = ServerHandshake::new(
            server_end,
            Guid::generate(),
            Some(Uid::effective().into()),
            None,
            None,
            CookieContext::default(),
        )
        .unwrap();

        let request = format!(
            "\0AUTH EXTERNAL {}\r\nNEGOTIATE_UNIX_FD\r\nBEGIN\r\n",
            hex::encode(sasl_auth_id().unwrap())
        );
        crate::utils::block_on(client_end.write_all(request.as_bytes())).unwrap();

        let authenticated = crate::utils::block_on(server.perform()).unwrap();

        assert!(authenticated.cap_unix_fd);
    }

    /// `AUTH EXTERNAL` without an initial response, with the identity sent in a separate
    /// `DATA` command.
    #[test]
    #[timeout(15000)]
    fn separate_external_data() {
        let (mut client_end, server_end) = create_async_socket_pair();
        let server = ServerHandshake::new(
            server_end,
            Guid::generate(),
            Some(Uid::effective().into()),
            None,
            None,
            CookieContext::default(),
        )
        .unwrap();

        let request = format!(
            "\0AUTH EXTERNAL\r\nDATA {}\r\nBEGIN\r\n",
            hex::encode(sasl_auth_id().unwrap())
        );
        crate::utils::block_on(client_end.write_all(request.as_bytes())).unwrap();

        crate::utils::block_on(server.perform()).unwrap();
    }

    /// `AUTH EXTERNAL` followed by a `DATA` command that carries no payload at all.
    #[test]
    #[timeout(15000)]
    fn missing_external_data() {
        let (mut client_end, server_end) = create_async_socket_pair();
        let server = ServerHandshake::new(
            server_end,
            Guid::generate(),
            Some(Uid::effective().into()),
            None,
            None,
            CookieContext::default(),
        )
        .unwrap();

        let request = b"\0AUTH EXTERNAL\r\nDATA\r\nBEGIN\r\n";
        crate::utils::block_on(client_end.write_all(request)).unwrap();

        crate::utils::block_on(server.perform()).unwrap();
    }

    /// `AUTH ANONYMOUS` with an initial response, against a server that only offers the
    /// anonymous mechanism.
    #[test]
    #[timeout(15000)]
    fn anonymous_handshake() {
        let (mut client_end, server_end) = create_async_socket_pair();
        let server = ServerHandshake::new(
            server_end,
            Guid::generate(),
            Some(Uid::effective().into()),
            Some(vec![AuthMechanism::Anonymous].into()),
            None,
            CookieContext::default(),
        )
        .unwrap();

        let request = b"\0AUTH ANONYMOUS abcd\r\nBEGIN\r\n";
        crate::utils::block_on(client_end.write_all(request)).unwrap();

        crate::utils::block_on(server.perform()).unwrap();
    }

    /// `AUTH ANONYMOUS` without an initial response, with the payload sent in a separate
    /// `DATA` command.
    #[test]
    #[timeout(15000)]
    fn separate_anonymous_data() {
        let (mut client_end, server_end) = create_async_socket_pair();
        let server = ServerHandshake::new(
            server_end,
            Guid::generate(),
            Some(Uid::effective().into()),
            Some(vec![AuthMechanism::Anonymous].into()),
            None,
            CookieContext::default(),
        )
        .unwrap();

        let request = b"\0AUTH ANONYMOUS\r\nDATA abcd\r\nBEGIN\r\n";
        crate::utils::block_on(client_end.write_all(request)).unwrap();

        crate::utils::block_on(server.perform()).unwrap();
    }
}