keyforkd: allow performing multiple requests on the same socket
This commit is contained in:
		
							parent
							
								
									a24a0166cc
								
							
						
					
					
						commit
						f1c24fb33e
					
				|  | @ -81,9 +81,6 @@ fn ed25519_test_suite() { | ||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
|                 for i in 2..chain_len { |                 for i in 2..chain_len { | ||||||
|                     // FIXME: Keyfork will only allow one request per session
 |  | ||||||
|                     let socket = UnixStream::connect(&socket_path).unwrap(); |  | ||||||
|                     let mut client = Client::new(socket); |  | ||||||
|                     // Consistency check: ensure the server and the client can each derive the same
 |                     // Consistency check: ensure the server and the client can each derive the same
 | ||||||
|                     // key using an XPrv, for all but the last XPrv, which is verified after this
 |                     // key using an XPrv, for all but the last XPrv, which is verified after this
 | ||||||
|                     let path = DerivationPath::from_str(test.chain).unwrap(); |                     let path = DerivationPath::from_str(test.chain).unwrap(); | ||||||
|  | @ -95,8 +92,6 @@ fn ed25519_test_suite() { | ||||||
|                         .fold(DerivationPath::default(), |p, i| p.chain_push(i.clone())); |                         .fold(DerivationPath::default(), |p, i| p.chain_push(i.clone())); | ||||||
|                     let xprv = dbg!(client.request_xprv::<SigningKey>(&left_path)).unwrap(); |                     let xprv = dbg!(client.request_xprv::<SigningKey>(&left_path)).unwrap(); | ||||||
|                     let derived_xprv = xprv.derive_path(&right_path).unwrap(); |                     let derived_xprv = xprv.derive_path(&right_path).unwrap(); | ||||||
|                     let socket = UnixStream::connect(&socket_path).unwrap(); |  | ||||||
|                     let mut client = Client::new(socket); |  | ||||||
|                     let keyforkd_xprv = client.request_xprv::<SigningKey>(&path).unwrap(); |                     let keyforkd_xprv = client.request_xprv::<SigningKey>(&path).unwrap(); | ||||||
|                     assert_eq!( |                     assert_eq!( | ||||||
|                         derived_xprv, keyforkd_xprv, |                         derived_xprv, keyforkd_xprv, | ||||||
|  |  | ||||||
|  | @ -68,55 +68,58 @@ impl UnixServer { | ||||||
|             #[cfg(feature = "tracing")] |             #[cfg(feature = "tracing")] | ||||||
|             debug!("new socket connected"); |             debug!("new socket connected"); | ||||||
|             tokio::spawn(async move { |             tokio::spawn(async move { | ||||||
|                 let bytes = match try_decode_from(&mut socket).await { |                 // Process requests until an error occurs or a client disconnects
 | ||||||
|                     Ok(bytes) => bytes, |                 loop { | ||||||
|                     Err(e) => { |                     let bytes = match try_decode_from(&mut socket).await { | ||||||
|                         #[cfg(feature = "tracing")] |                         Ok(bytes) => bytes, | ||||||
|                         debug!(%e, "Error reading DerivationPath from socket"); |                         Err(e) => { | ||||||
|                         let content = e.to_string().bytes().collect::<Vec<_>>(); |                             #[cfg(feature = "tracing")] | ||||||
|                         let result = try_encode_to(&content[..], &mut socket).await; |                             debug!(%e, "Error reading DerivationPath from socket"); | ||||||
|                         #[cfg(feature = "tracing")] |                             let content = e.to_string().bytes().collect::<Vec<_>>(); | ||||||
|                         if let Err(error) = result { |                             let result = try_encode_to(&content[..], &mut socket).await; | ||||||
|                             debug!(%error, "Error sending error to client"); |                             #[cfg(feature = "tracing")] | ||||||
|  |                             if let Err(error) = result { | ||||||
|  |                                 debug!(%error, "Error sending error to client"); | ||||||
|  |                             } | ||||||
|  |                             return; | ||||||
|                         } |                         } | ||||||
|                         return; |                     }; | ||||||
|                     } |  | ||||||
|                 }; |  | ||||||
| 
 | 
 | ||||||
|                 let app = match app.ready().await { |                     let app = match app.ready().await { | ||||||
|                     Ok(app) => app, |                         Ok(app) => app, | ||||||
|                     Err(e) => { |                         Err(e) => { | ||||||
|                         #[cfg(feature = "tracing")] |                             #[cfg(feature = "tracing")] | ||||||
|                         debug!(%e, "Could not poll ready"); |                             debug!(%e, "Could not poll ready"); | ||||||
|                         let content = e.to_string().bytes().collect::<Vec<_>>(); |                             let content = e.to_string().bytes().collect::<Vec<_>>(); | ||||||
|                         let result = try_encode_to(&content[..], &mut socket).await; |                             let result = try_encode_to(&content[..], &mut socket).await; | ||||||
|                         #[cfg(feature = "tracing")] |                             #[cfg(feature = "tracing")] | ||||||
|                         if let Err(error) = result { |                             if let Err(error) = result { | ||||||
|                             debug!(%error, "Error sending error to client"); |                                 debug!(%error, "Error sending error to client"); | ||||||
|  |                             } | ||||||
|  |                             return; | ||||||
|                         } |                         } | ||||||
|                         return; |                     }; | ||||||
|                     } |  | ||||||
|                 }; |  | ||||||
| 
 | 
 | ||||||
|                 let response = match app.call(bytes.into()).await { |                     let response = match app.call(bytes.into()).await { | ||||||
|                     Ok(response) => response, |                         Ok(response) => response, | ||||||
|                     Err(e) => { |                         Err(e) => { | ||||||
|                         #[cfg(feature = "tracing")] |                             #[cfg(feature = "tracing")] | ||||||
|                         debug!(%e, "Error reading DerivationPath from socket"); |                             debug!(%e, "Error reading DerivationPath from socket"); | ||||||
|                         let content = e.to_string().bytes().collect::<Vec<_>>(); |                             let content = e.to_string().bytes().collect::<Vec<_>>(); | ||||||
|                         let result = try_encode_to(&content[..], &mut socket).await; |                             let result = try_encode_to(&content[..], &mut socket).await; | ||||||
|                         #[cfg(feature = "tracing")] |                             #[cfg(feature = "tracing")] | ||||||
|                         if let Err(error) = result { |                             if let Err(error) = result { | ||||||
|                             debug!(%error, "Error sending error to client"); |                                 debug!(%error, "Error sending error to client"); | ||||||
|  |                             } | ||||||
|  |                             return; | ||||||
|                         } |                         } | ||||||
|                         return; |  | ||||||
|                     } |                     } | ||||||
|                 } |                     .into(); | ||||||
|                 .into(); |  | ||||||
| 
 | 
 | ||||||
|                 if let Err(e) = try_encode_to(&response[..], &mut socket).await { |                     if let Err(e) = try_encode_to(&response[..], &mut socket).await { | ||||||
|                     #[cfg(feature = "tracing")] |                         #[cfg(feature = "tracing")] | ||||||
|                     debug!(%e, "Error sending response to client"); |                         debug!(%e, "Error sending response to client"); | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|             }); |             }); | ||||||
|         } |         } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue