1use crate::{
2 MctpMessage, MctpMessageHeaderTrait, MctpMessageTrait, MctpPacketError,
3 deserialize::{map_decode_err, parse_message_body, parse_transport_header},
4 endpoint_id::EndpointId,
5 error::{MctpPacketResult, ProtocolError},
6 mctp_message_tag::MctpMessageTag,
7 mctp_sequence_number::MctpSequenceNumber,
8 medium::{MctpMedium, MctpMediumFrame},
9 serialize::SerializePacketState,
10};
11
/// Header fields captured from the final (end-of-message) packet of a received message.
///
/// `MctpPacketContext::deserialize_packet` fills this in when it completes a message.
/// `MctpPacketContext::serialize_packet` takes one as its `reply_context` argument.
// NOTE(review): whether source/destination are swapped when replying is decided by
// `SerializePacketState`, which is not visible here — confirm before relying on it.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct MctpReplyContext<M: MctpMedium> {
    /// Destination endpoint ID taken from the received packet's transport header.
    pub destination_endpoint_id: EndpointId,
    /// Source endpoint ID taken from the received packet's transport header.
    pub source_endpoint_id: EndpointId,
    /// Packet sequence number of the last packet of the received message.
    pub packet_sequence_number: MctpSequenceNumber,
    /// Message tag of the received message.
    pub message_tag: MctpMessageTag,
    /// Medium-specific data obtained from `MctpMediumFrame::reply_context()`.
    pub medium_context: M::ReplyContext,
}
24
/// MCTP packet reassembler / serializer bound to a single medium.
///
/// The caller-provided `packet_assembly_buffer` is scratch space for two jobs:
/// reassembling incoming multi-packet messages (`deserialize_packet`) and staging
/// outgoing ones (`serialize_packet`). Only one message can be reassembled at a time.
pub struct MctpPacketContext<'buf, M: MctpMedium> {
    // Whether a multi-packet message is currently being reassembled, and if so,
    // what the following packets must match.
    assembly_state: AssemblyState,
    // Medium used to unwrap incoming frames and, via `SerializePacketState`, wrap outgoing ones.
    medium: M,
    // Shared scratch buffer. `serialize_packet` refuses to run while an assembly is in
    // progress because it writes from index 0 and would overwrite the partial message.
    packet_assembly_buffer: &'buf mut [u8],
}
32
33impl<'buf, M: MctpMedium> MctpPacketContext<'buf, M> {
34 pub fn new(medium: M, packet_assembly_buffer: &'buf mut [u8]) -> Self {
35 Self {
36 medium,
37 assembly_state: AssemblyState::Idle,
38 packet_assembly_buffer,
39 }
40 }
41
42 pub fn deserialize_packet(
43 &mut self,
44 packet: &[u8],
45 ) -> MctpPacketResult<Option<MctpMessage<'_, M>>, M> {
46 let (medium_frame, mut decoder) = self.medium.deserialize(packet)?;
47 let transport_header = parse_transport_header::<M>(&mut decoder)?;
48
49 let mut state = match self.assembly_state {
50 AssemblyState::Idle => {
51 if transport_header.start_of_message == 0 {
52 return Err(MctpPacketError::ProtocolError(
53 ProtocolError::ExpectedStartOfMessage,
54 ));
55 }
56
57 AssemblingState {
58 message_tag: transport_header.message_tag,
59 tag_owner: transport_header.tag_owner,
60 source_endpoint_id: transport_header.source_endpoint_id,
61 packet_sequence_number: transport_header.packet_sequence_number,
62 packet_assembly_buffer_index: 0,
63 }
64 }
65 AssemblyState::Receiving(assembling_state) => {
66 if transport_header.start_of_message != 0 {
67 return Err(MctpPacketError::ProtocolError(
68 ProtocolError::UnexpectedStartOfMessage,
69 ));
70 }
71 if assembling_state.message_tag != transport_header.message_tag {
72 return Err(MctpPacketError::ProtocolError(
73 ProtocolError::MessageTagMismatch(
74 assembling_state.message_tag,
75 transport_header.message_tag,
76 ),
77 ));
78 }
79 if assembling_state.tag_owner != transport_header.tag_owner {
80 return Err(MctpPacketError::ProtocolError(
81 ProtocolError::TagOwnerMismatch(
82 assembling_state.tag_owner,
83 transport_header.tag_owner,
84 ),
85 ));
86 }
87 if assembling_state.source_endpoint_id != transport_header.source_endpoint_id {
88 return Err(MctpPacketError::ProtocolError(
89 ProtocolError::SourceEndpointIdMismatch(
90 assembling_state.source_endpoint_id,
91 transport_header.source_endpoint_id,
92 ),
93 ));
94 }
95 let expected_sequence_number = assembling_state.packet_sequence_number.next();
96 if expected_sequence_number != transport_header.packet_sequence_number {
97 return Err(MctpPacketError::ProtocolError(
98 ProtocolError::UnexpectedPacketSequenceNumber(
99 expected_sequence_number,
100 transport_header.packet_sequence_number,
101 ),
102 ));
103 }
104 assembling_state
105 }
106 };
107
108 let buffer_idx = state.packet_assembly_buffer_index;
109 let packet_size = medium_frame.packet_size();
110 if packet_size < 4 {
111 return Err(MctpPacketError::HeaderParseError(
112 "transport frame indicated packet length < 4",
113 ));
114 }
115 let packet_size = packet_size - 4; if buffer_idx + packet_size > self.packet_assembly_buffer.len() {
118 return Err(MctpPacketError::HeaderParseError(
119 "packet assembly buffer overflow - insufficient space",
120 ));
121 }
122 for i in 0..packet_size {
130 self.packet_assembly_buffer[buffer_idx + i] = decoder.read().map_err(|e| {
131 map_decode_err::<M>(
132 e,
133 "packet body too short to extract expected decoded bytes",
134 "Invalid encoding escape sequence in packet body",
135 )
136 })?;
137 }
138 state.packet_assembly_buffer_index += packet_size;
139
140 let message = if transport_header.end_of_message == 1 {
141 self.assembly_state = AssemblyState::Idle;
142 let (message_body, message_integrity_check) = parse_message_body::<M>(
143 &self.packet_assembly_buffer[..state.packet_assembly_buffer_index],
144 )?;
145 Some(MctpMessage {
146 reply_context: MctpReplyContext {
147 destination_endpoint_id: transport_header.destination_endpoint_id,
148 source_endpoint_id: transport_header.source_endpoint_id,
149 packet_sequence_number: transport_header.packet_sequence_number,
150 message_tag: transport_header.message_tag,
151 medium_context: medium_frame.reply_context(),
152 },
153 message_buffer: message_body,
154 message_integrity_check,
155 })
156 } else {
157 self.assembly_state = AssemblyState::Receiving(state);
158 None
159 };
160
161 Ok(message)
162 }
163
164 pub fn serialize_packet<P: MctpMessageTrait<'buf>>(
165 &'buf mut self,
166 reply_context: MctpReplyContext<M>,
167 message: (P::Header, P),
168 ) -> MctpPacketResult<SerializePacketState<'buf, M>, M> {
169 match self.assembly_state {
170 AssemblyState::Idle => {}
171 _ => {
172 return Err(MctpPacketError::ProtocolError(
173 ProtocolError::SendMessageWhileAssembling,
174 ));
175 }
176 };
177
178 self.packet_assembly_buffer[0] = P::MESSAGE_TYPE;
179 let header_size = message.0.serialize(&mut self.packet_assembly_buffer[1..])?;
180 let body_size = message
181 .1
182 .serialize(&mut self.packet_assembly_buffer[header_size + 1..])?;
183
184 let (message, rest) = self
185 .packet_assembly_buffer
186 .split_at_mut(header_size + body_size + 1);
187
188 Ok(SerializePacketState {
189 medium: &self.medium,
190 reply_context,
191 current_packet_num: 0,
192 serialized_message_header: false,
193 message_buffer: message,
194 assembly_buffer: rest,
195 })
196 }
197}
198
/// Reassembly state of a `MctpPacketContext`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum AssemblyState {
    /// No message in progress. `deserialize_packet` only accepts a packet with the
    /// start-of-message flag set, and `serialize_packet` is allowed.
    Idle,
    /// A message has been started. The payload records what subsequent packets must
    /// match and how much of the assembly buffer is already filled.
    Receiving(AssemblingState),
}
205
/// Bookkeeping for a message that is partially reassembled.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
struct AssemblingState {
    // Copied from the start-of-message packet. Continuation packets must carry the same value.
    message_tag: MctpMessageTag,
    // Copied from the start-of-message packet. Continuation packets must carry the same value.
    tag_owner: u8,
    // Copied from the start-of-message packet. Continuation packets must carry the same value.
    source_endpoint_id: EndpointId,
    // Reference point for the sequence check: `deserialize_packet` expects the next
    // packet to carry `packet_sequence_number.next()`.
    packet_sequence_number: MctpSequenceNumber,
    // Number of payload bytes already written to the packet assembly buffer. The next
    // packet's payload is appended starting at this index.
    packet_assembly_buffer_index: usize,
}