Skip to main content

mctp_rs/
mctp_packet_context.rs

1use crate::{
2    MctpMessage, MctpMessageHeaderTrait, MctpMessageTrait, MctpPacketError,
3    deserialize::{map_decode_err, parse_message_body, parse_transport_header},
4    endpoint_id::EndpointId,
5    error::{MctpPacketResult, ProtocolError},
6    mctp_message_tag::MctpMessageTag,
7    mctp_sequence_number::MctpSequenceNumber,
8    medium::{MctpMedium, MctpMediumFrame},
9    serialize::SerializePacketState,
10};
11
12/// Represents the state needed to construct a repsonse to a request:
13/// the MCTP transport source/destination, the sequence number to use for
14/// the reply, and the medium-specific context that came with the request.
15#[derive(Debug, PartialEq, Eq)]
16#[cfg_attr(feature = "defmt", derive(defmt::Format))]
17pub struct MctpReplyContext<M: MctpMedium> {
18    pub destination_endpoint_id: EndpointId,
19    pub source_endpoint_id: EndpointId,
20    pub packet_sequence_number: MctpSequenceNumber,
21    pub message_tag: MctpMessageTag,
22    pub medium_context: M::ReplyContext,
23}
24
25/// Context for serializing and deserializing an MCTP message, which may be split among multiple
26/// packets.
27pub struct MctpPacketContext<'buf, M: MctpMedium> {
28    assembly_state: AssemblyState,
29    medium: M,
30    packet_assembly_buffer: &'buf mut [u8],
31}
32
33impl<'buf, M: MctpMedium> MctpPacketContext<'buf, M> {
34    pub fn new(medium: M, packet_assembly_buffer: &'buf mut [u8]) -> Self {
35        Self {
36            medium,
37            assembly_state: AssemblyState::Idle,
38            packet_assembly_buffer,
39        }
40    }
41
42    pub fn deserialize_packet(
43        &mut self,
44        packet: &[u8],
45    ) -> MctpPacketResult<Option<MctpMessage<'_, M>>, M> {
46        let (medium_frame, mut decoder) = self.medium.deserialize(packet)?;
47        let transport_header = parse_transport_header::<M>(&mut decoder)?;
48
49        let mut state = match self.assembly_state {
50            AssemblyState::Idle => {
51                if transport_header.start_of_message == 0 {
52                    return Err(MctpPacketError::ProtocolError(
53                        ProtocolError::ExpectedStartOfMessage,
54                    ));
55                }
56
57                AssemblingState {
58                    message_tag: transport_header.message_tag,
59                    tag_owner: transport_header.tag_owner,
60                    source_endpoint_id: transport_header.source_endpoint_id,
61                    packet_sequence_number: transport_header.packet_sequence_number,
62                    packet_assembly_buffer_index: 0,
63                }
64            }
65            AssemblyState::Receiving(assembling_state) => {
66                if transport_header.start_of_message != 0 {
67                    return Err(MctpPacketError::ProtocolError(
68                        ProtocolError::UnexpectedStartOfMessage,
69                    ));
70                }
71                if assembling_state.message_tag != transport_header.message_tag {
72                    return Err(MctpPacketError::ProtocolError(
73                        ProtocolError::MessageTagMismatch(
74                            assembling_state.message_tag,
75                            transport_header.message_tag,
76                        ),
77                    ));
78                }
79                if assembling_state.tag_owner != transport_header.tag_owner {
80                    return Err(MctpPacketError::ProtocolError(
81                        ProtocolError::TagOwnerMismatch(
82                            assembling_state.tag_owner,
83                            transport_header.tag_owner,
84                        ),
85                    ));
86                }
87                if assembling_state.source_endpoint_id != transport_header.source_endpoint_id {
88                    return Err(MctpPacketError::ProtocolError(
89                        ProtocolError::SourceEndpointIdMismatch(
90                            assembling_state.source_endpoint_id,
91                            transport_header.source_endpoint_id,
92                        ),
93                    ));
94                }
95                let expected_sequence_number = assembling_state.packet_sequence_number.next();
96                if expected_sequence_number != transport_header.packet_sequence_number {
97                    return Err(MctpPacketError::ProtocolError(
98                        ProtocolError::UnexpectedPacketSequenceNumber(
99                            expected_sequence_number,
100                            transport_header.packet_sequence_number,
101                        ),
102                    ));
103                }
104                assembling_state
105            }
106        };
107
108        let buffer_idx = state.packet_assembly_buffer_index;
109        let packet_size = medium_frame.packet_size();
110        if packet_size < 4 {
111            return Err(MctpPacketError::HeaderParseError(
112                "transport frame indicated packet length < 4",
113            ));
114        }
115        let packet_size = packet_size - 4; // to account for the transport header
116        // Check assembly buffer bounds (decoded bytes destination)
117        if buffer_idx + packet_size > self.packet_assembly_buffer.len() {
118            return Err(MctpPacketError::HeaderParseError(
119                "packet assembly buffer overflow - insufficient space",
120            ));
121        }
122        // Decode `packet_size` payload bytes from the (possibly stuffed) wire
123        // buffer into the assembly buffer one byte at a time via the
124        // medium-supplied decoder. We do NOT pre-check
125        // `decoder.remaining_wire() < packet_size` because for stuffing
126        // encodings wire length is not decoded length; PrematureEnd from
127        // `read()` is the canonical "ran out of bytes while decoding the
128        // body" signal.
129        for i in 0..packet_size {
130            self.packet_assembly_buffer[buffer_idx + i] = decoder.read().map_err(|e| {
131                map_decode_err::<M>(
132                    e,
133                    "packet body too short to extract expected decoded bytes",
134                    "Invalid encoding escape sequence in packet body",
135                )
136            })?;
137        }
138        state.packet_assembly_buffer_index += packet_size;
139
140        let message = if transport_header.end_of_message == 1 {
141            self.assembly_state = AssemblyState::Idle;
142            let (message_body, message_integrity_check) = parse_message_body::<M>(
143                &self.packet_assembly_buffer[..state.packet_assembly_buffer_index],
144            )?;
145            Some(MctpMessage {
146                reply_context: MctpReplyContext {
147                    destination_endpoint_id: transport_header.destination_endpoint_id,
148                    source_endpoint_id: transport_header.source_endpoint_id,
149                    packet_sequence_number: transport_header.packet_sequence_number,
150                    message_tag: transport_header.message_tag,
151                    medium_context: medium_frame.reply_context(),
152                },
153                message_buffer: message_body,
154                message_integrity_check,
155            })
156        } else {
157            self.assembly_state = AssemblyState::Receiving(state);
158            None
159        };
160
161        Ok(message)
162    }
163
164    pub fn serialize_packet<P: MctpMessageTrait<'buf>>(
165        &'buf mut self,
166        reply_context: MctpReplyContext<M>,
167        message: (P::Header, P),
168    ) -> MctpPacketResult<SerializePacketState<'buf, M>, M> {
169        match self.assembly_state {
170            AssemblyState::Idle => {}
171            _ => {
172                return Err(MctpPacketError::ProtocolError(
173                    ProtocolError::SendMessageWhileAssembling,
174                ));
175            }
176        };
177
178        self.packet_assembly_buffer[0] = P::MESSAGE_TYPE;
179        let header_size = message.0.serialize(&mut self.packet_assembly_buffer[1..])?;
180        let body_size = message
181            .1
182            .serialize(&mut self.packet_assembly_buffer[header_size + 1..])?;
183
184        let (message, rest) = self
185            .packet_assembly_buffer
186            .split_at_mut(header_size + body_size + 1);
187
188        Ok(SerializePacketState {
189            medium: &self.medium,
190            reply_context,
191            current_packet_num: 0,
192            serialized_message_header: false,
193            message_buffer: message,
194            assembly_buffer: rest,
195        })
196    }
197}
198
199#[derive(Debug, Copy, Clone, PartialEq, Eq)]
200#[cfg_attr(feature = "defmt", derive(defmt::Format))]
201enum AssemblyState {
202    Idle,
203    Receiving(AssemblingState),
204}
205
206#[derive(Debug, Copy, Clone, PartialEq, Eq)]
207#[cfg_attr(feature = "defmt", derive(defmt::Format))]
208struct AssemblingState {
209    message_tag: MctpMessageTag,
210    tag_owner: u8,
211    source_endpoint_id: EndpointId,
212    packet_sequence_number: MctpSequenceNumber,
213    packet_assembly_buffer_index: usize,
214}