@@ -7,6 +7,7 @@ mod tasks;
7
7
8
8
use std:: net:: SocketAddr ;
9
9
10
+ use bytes:: Bytes ;
10
11
use clap:: Clap ;
11
12
use hyper:: service:: { make_service_fn, service_fn} ;
12
13
use hyper:: { body:: to_bytes, header:: HeaderValue , Body , Method , Request , Response , StatusCode } ;
@@ -22,97 +23,102 @@ use crate::metrics::Metrics;
22
23
use crate :: tasks:: { TaskCode , TaskManager } ;
23
24
24
25
const SERVER_INFO : & str = concat ! ( env!( "CARGO_PKG_NAME" ) , "/" , env!( "CARGO_PKG_VERSION" ) ) ;
25
- const NOT_FOUND : & [ u8 ] = b"Not Found" ;
26
+ const RESPONSE_DEFAULT : & [ u8 ] = b"MOSEC service" ;
27
+ const RESPONSE_NOT_FOUND : & [ u8 ] = b"not found" ;
28
+ const RESPONSE_EMPTY : & [ u8 ] = b"no data provided" ;
29
+ const RESPONSE_SHUTDOWN : & [ u8 ] = b"gracefully shutting down" ;
26
30
27
- async fn index ( _: Request < Body > ) -> Result < Response < Body > , ServiceError > {
31
+ async fn index ( _: Request < Body > ) -> Response < Body > {
28
32
let task_manager = TaskManager :: global ( ) ;
29
33
if task_manager. is_shutdown ( ) {
30
- return Err ( ServiceError :: GracefulShutdown ) ;
34
+ build_response (
35
+ StatusCode :: SERVICE_UNAVAILABLE ,
36
+ Bytes :: from_static ( RESPONSE_SHUTDOWN ) ,
37
+ )
38
+ } else {
39
+ build_response ( StatusCode :: OK , Bytes :: from_static ( RESPONSE_DEFAULT ) )
31
40
}
32
- Ok ( Response :: new ( Body :: from ( "MOSEC service" ) ) )
33
41
}
34
42
35
- async fn metrics ( _: Request < Body > ) -> Result < Response < Body > , ServiceError > {
43
+ async fn metrics ( _: Request < Body > ) -> Response < Body > {
36
44
let encoder = TextEncoder :: new ( ) ;
37
45
let metrics = prometheus:: gather ( ) ;
38
46
let mut buffer = vec ! [ ] ;
39
47
encoder. encode ( & metrics, & mut buffer) . unwrap ( ) ;
40
- Ok ( Response :: new ( Body :: from ( buffer) ) )
48
+ build_response ( StatusCode :: OK , Bytes :: from ( buffer) )
41
49
}
42
50
43
- async fn inference ( req : Request < Body > ) -> Result < Response < Body > , ServiceError > {
51
+ async fn inference ( req : Request < Body > ) -> Response < Body > {
44
52
let task_manager = TaskManager :: global ( ) ;
45
53
let data = to_bytes ( req. into_body ( ) ) . await . unwrap ( ) ;
46
54
let metrics = Metrics :: global ( ) ;
47
55
56
+ if task_manager. is_shutdown ( ) {
57
+ return build_response (
58
+ StatusCode :: SERVICE_UNAVAILABLE ,
59
+ Bytes :: from_static ( RESPONSE_SHUTDOWN ) ,
60
+ ) ;
61
+ }
62
+
48
63
if data. is_empty ( ) {
49
- return Ok ( Response :: new ( Body :: from ( "No data provided" ) ) ) ;
64
+ return build_response ( StatusCode :: OK , Bytes :: from_static ( RESPONSE_EMPTY ) ) ;
50
65
}
51
66
67
+ let ( status, content) ;
52
68
metrics. remaining_task . inc ( ) ;
53
- let task = task_manager. submit_task ( data) . await ?;
54
- match task. code {
55
- TaskCode :: Normal => {
56
- metrics. remaining_task . dec ( ) ;
57
- metrics
58
- . duration
59
- . with_label_values ( & [ "total" , "total" ] )
60
- . observe ( task. create_at . elapsed ( ) . as_secs_f64 ( ) ) ;
61
- metrics
62
- . throughput
63
- . with_label_values ( & [ StatusCode :: OK . as_str ( ) ] )
64
- . inc ( ) ;
65
- Ok ( Response :: new ( Body :: from ( task. data ) ) )
69
+ match task_manager. submit_task ( data) . await {
70
+ Ok ( task) => {
71
+ content = task. data ;
72
+ status = match task. code {
73
+ TaskCode :: Normal => {
74
+ // Record latency only for successful tasks
75
+ metrics
76
+ . duration
77
+ . with_label_values ( & [ "total" , "total" ] )
78
+ . observe ( task. create_at . elapsed ( ) . as_secs_f64 ( ) ) ;
79
+ StatusCode :: OK
80
+ }
81
+ TaskCode :: BadRequestError => StatusCode :: BAD_REQUEST ,
82
+ TaskCode :: ValidationError => StatusCode :: UNPROCESSABLE_ENTITY ,
83
+ TaskCode :: InternalError => StatusCode :: INTERNAL_SERVER_ERROR ,
84
+ }
85
+ }
86
+ Err ( err) => {
87
+ // Handle errors for which tasks cannot be retrieved
88
+ content = Bytes :: from ( err. to_string ( ) ) ;
89
+ status = match err {
90
+ ServiceError :: TooManyRequests => StatusCode :: TOO_MANY_REQUESTS ,
91
+ ServiceError :: Timeout => StatusCode :: REQUEST_TIMEOUT ,
92
+ ServiceError :: UnknownError => StatusCode :: INTERNAL_SERVER_ERROR ,
93
+ } ;
66
94
}
67
- TaskCode :: BadRequestError => Err ( ServiceError :: BadRequestError ) ,
68
- TaskCode :: ValidationError => Err ( ServiceError :: ValidationError ) ,
69
- TaskCode :: InternalError => Err ( ServiceError :: InternalError ) ,
70
- TaskCode :: UnknownError => Err ( ServiceError :: UnknownError ) ,
71
95
}
72
- }
73
-
74
- fn error_handler ( err : ServiceError ) -> Response < Body > {
75
- let status = match err {
76
- ServiceError :: Timeout => StatusCode :: REQUEST_TIMEOUT ,
77
- ServiceError :: BadRequestError => StatusCode :: BAD_REQUEST ,
78
- ServiceError :: TooManyRequests => StatusCode :: TOO_MANY_REQUESTS ,
79
- ServiceError :: ValidationError => StatusCode :: UNPROCESSABLE_ENTITY ,
80
- ServiceError :: InternalError => StatusCode :: INTERNAL_SERVER_ERROR ,
81
- ServiceError :: GracefulShutdown => StatusCode :: SERVICE_UNAVAILABLE ,
82
- ServiceError :: UnknownError => StatusCode :: NOT_IMPLEMENTED ,
83
- } ;
84
- let metrics = Metrics :: global ( ) ;
85
-
86
96
metrics. remaining_task . dec ( ) ;
87
97
metrics
88
98
. throughput
89
99
. with_label_values ( & [ status. as_str ( ) ] )
90
100
. inc ( ) ;
91
101
102
+ build_response ( status, content)
103
+ }
104
+
105
+ fn build_response ( status : StatusCode , content : Bytes ) -> Response < Body > {
92
106
Response :: builder ( )
93
107
. status ( status)
94
108
. header ( "server" , HeaderValue :: from_static ( SERVER_INFO ) )
95
- . body ( Body :: from ( err . to_string ( ) ) )
109
+ . body ( Body :: from ( content ) )
96
110
. unwrap ( )
97
111
}
98
112
99
113
async fn service_func ( req : Request < Body > ) -> Result < Response < Body > , hyper:: Error > {
100
- let res = match ( req. method ( ) , req. uri ( ) . path ( ) ) {
101
- ( & Method :: GET , "/" ) => index ( req) . await ,
102
- ( & Method :: GET , "/metrics" ) => metrics ( req) . await ,
103
- ( & Method :: POST , "/inference" ) => inference ( req) . await ,
104
- _ => Ok ( Response :: builder ( )
105
- . status ( StatusCode :: NOT_FOUND )
106
- . body ( NOT_FOUND . into ( ) )
107
- . unwrap ( ) ) ,
108
- } ;
109
- match res {
110
- Ok ( mut resp) => {
111
- resp. headers_mut ( )
112
- . insert ( "server" , HeaderValue :: from_static ( SERVER_INFO ) ) ;
113
- Ok ( resp)
114
- }
115
- Err ( err) => Ok ( error_handler ( err) ) ,
114
+ match ( req. method ( ) , req. uri ( ) . path ( ) ) {
115
+ ( & Method :: GET , "/" ) => Ok ( index ( req) . await ) ,
116
+ ( & Method :: GET , "/metrics" ) => Ok ( metrics ( req) . await ) ,
117
+ ( & Method :: POST , "/inference" ) => Ok ( inference ( req) . await ) ,
118
+ _ => Ok ( build_response (
119
+ StatusCode :: NOT_FOUND ,
120
+ Bytes :: from ( RESPONSE_NOT_FOUND ) ,
121
+ ) ) ,
116
122
}
117
123
}
118
124
0 commit comments