Skip to content

Commit 1627a9b

Browse files
lkevinzc and zclzc authored
Custom err msg (#76)
* support custom 422 msg by raising ValidationError
* fix grammar
* clean rs err handling logic
* minor naming change

Co-authored-by: zclzc <[email protected]>
1 parent e01fcaf commit 1627a9b

12 files changed

+127
-103
lines changed

Makefile

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dev:
1313
cp ./target/debug/mosec mosec/bin/
1414
pip install -e .
1515

16-
test:
16+
test: dev
1717
pytest tests -vv -s
1818
RUST_BACKTRACE=1 cargo test -vv
1919

@@ -44,5 +44,6 @@ lint:
4444
flake8 ${PY_SOURCE_FILES} --count --show-source --statistics
4545
mypy --install-types --non-interactive ${PY_SOURCE_FILES}
4646
cargo +nightly fmt -- --check
47+
cargo clippy
4748

4849
.PHONY: test doc

examples/echo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def forward(self, data: dict) -> float:
2121
try:
2222
time = float(data["time"])
2323
except KeyError as err:
24-
raise ValidationError(err)
24+
raise ValidationError(f"cannot find key {err}")
2525
return time
2626

2727

examples/resnet50_server_pytorch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def forward(self, req: dict) -> np.ndarray:
3232
im = np.frombuffer(base64.b64decode(image), np.uint8)
3333
im = cv2.imdecode(im, cv2.IMREAD_COLOR)[:, :, ::-1] # bgr -> rgb
3434
except KeyError as err:
35-
raise ValidationError(f"bad request: {err}")
35+
raise ValidationError(f"cannot find key {err}")
3636
except Exception as err:
3737
raise ValidationError(f"cannot decode as image data: {err}")
3838

mosec/coordinator.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from multiprocessing.synchronize import Event
99
from typing import Callable, Type
1010

11-
from .errors import ValidationError
11+
from .errors import DecodingError, ValidationError
1212
from .protocol import Protocol
1313
from .worker import Worker
1414

@@ -160,15 +160,24 @@ def coordinate(self):
160160
"returned data doesn't match the input data:"
161161
f"input({len(data)})!=output({len(payloads)})"
162162
)
163-
except ValidationError as err:
163+
except DecodingError as err:
164164
err_msg = str(err).replace("\n", " - ")
165+
err_msg = (
166+
err_msg if len(err_msg) else "cannot deserialize request bytes"
167+
)
168+
logger.info(f"{self.name} decoding error: {err_msg}")
169+
status = self.protocol.FLAG_BAD_REQUEST
170+
payloads = (f"decoding error: {err_msg}".encode(),)
171+
except ValidationError as err:
172+
err_msg = str(err)
173+
err_msg = err_msg if len(err_msg) else "invalid data format"
165174
logger.info(f"{self.name} validation error: {err_msg}")
166175
status = self.protocol.FLAG_VALIDATION_ERROR
167-
payloads = (f"Validation Error: {err_msg}".encode(),)
176+
payloads = (f"validation error: {err_msg}".encode(),)
168177
except Exception:
169178
logger.warning(traceback.format_exc().replace("\n", " "))
170179
status = self.protocol.FLAG_INTERNAL_ERROR
171-
payloads = ("Internal Error".encode(),)
180+
payloads = ("inference internal error".encode(),)
172181

173182
try:
174183
self.protocol.send(status, ids, payloads)

mosec/errors.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,31 @@
1+
"""
2+
Suppose the input dataflow of our model server is as follows:
3+
4+
**bytes** --- *deserialize*<sup>(decoding)</sup> ---> **data**
5+
--- *parse*<sup>(validation)</sup> ---> **valid data**
6+
7+
If the raw bytes cannot be successfully deserialized, the `DecodingError`
8+
is raised; if the decoded data cannot pass the validation check (usually
9+
implemented by users), the `ValidationError` should be raised.
10+
"""
11+
12+
13+
class DecodingError(Exception):
14+
"""
15+
The `DecodingError` should be raised in user-implemented code
16+
when the de-serialization for the request bytes fails. This error
17+
will set the status code to
18+
[HTTP 400](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400)
19+
in the response.
20+
"""
21+
22+
123
class ValidationError(Exception):
224
"""
325
The `ValidationError` should be raised in user-implemented code,
4-
where the validation for the input data fails. Usually it can be put
5-
after the data deserialization, which converts the raw bytes into
6-
structured data.
26+
where the validation for the input data fails. Usually, it should be
27+
put after the data de-serialization, which converts the raw bytes
28+
into structured data. This error will set the status code to
29+
[HTTP 422](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/422)
30+
in the response.
731
"""

mosec/worker.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pickle
44
from typing import Any
55

6-
from .errors import ValidationError
6+
from .errors import DecodingError
77

88
logger = logging.getLogger(__name__)
99

@@ -89,7 +89,7 @@ def deserialize(self, data: bytes) -> Any:
8989
try:
9090
data_json = json.loads(data) if data else {}
9191
except Exception as err:
92-
raise ValidationError(err)
92+
raise DecodingError(err)
9393
return data_json
9494

9595
def forward(self, data: Any) -> Any:

src/coordinator.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ impl Coordinator {
2727
let (sender, receiver) = bounded(opts.capacity);
2828
let timeout = Duration::from_millis(opts.timeout);
2929
let wait_time = Duration::from_millis(opts.wait);
30-
let path = if opts.path.len() > 0 {
30+
let path = if !opts.path.is_empty() {
3131
opts.path.to_string()
3232
} else {
3333
// default IPC path
@@ -44,7 +44,7 @@ impl Coordinator {
4444

4545
Self {
4646
capacity: opts.capacity,
47-
path: path,
47+
path,
4848
batches: opts.batches.clone(),
4949
wait_time,
5050
timeout,

src/errors.rs

+1-13
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,9 @@ pub(crate) enum ServiceError {
55
#[display(fmt = "inference timeout")]
66
Timeout,
77

8-
#[display(fmt = "bad request")]
9-
BadRequestError,
10-
11-
#[display(fmt = "bad request: validation error")]
12-
ValidationError,
13-
14-
#[display(fmt = "inference internal error")]
15-
InternalError,
16-
17-
#[display(fmt = "too many request: channel is full")]
8+
#[display(fmt = "too many request: task queue is full")]
189
TooManyRequests,
1910

20-
#[display(fmt = "cannot accept new request during the graceful shutdown")]
21-
GracefulShutdown,
22-
2311
#[display(fmt = "mosec unknown error")]
2412
UnknownError,
2513
}

src/main.rs

+62-56
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod tasks;
77

88
use std::net::SocketAddr;
99

10+
use bytes::Bytes;
1011
use clap::Clap;
1112
use hyper::service::{make_service_fn, service_fn};
1213
use hyper::{body::to_bytes, header::HeaderValue, Body, Method, Request, Response, StatusCode};
@@ -22,97 +23,102 @@ use crate::metrics::Metrics;
2223
use crate::tasks::{TaskCode, TaskManager};
2324

2425
const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
25-
const NOT_FOUND: &[u8] = b"Not Found";
26+
const RESPONSE_DEFAULT: &[u8] = b"MOSEC service";
27+
const RESPONSE_NOT_FOUND: &[u8] = b"not found";
28+
const RESPONSE_EMPTY: &[u8] = b"no data provided";
29+
const RESPONSE_SHUTDOWN: &[u8] = b"gracefully shutting down";
2630

27-
async fn index(_: Request<Body>) -> Result<Response<Body>, ServiceError> {
31+
async fn index(_: Request<Body>) -> Response<Body> {
2832
let task_manager = TaskManager::global();
2933
if task_manager.is_shutdown() {
30-
return Err(ServiceError::GracefulShutdown);
34+
build_response(
35+
StatusCode::SERVICE_UNAVAILABLE,
36+
Bytes::from_static(RESPONSE_SHUTDOWN),
37+
)
38+
} else {
39+
build_response(StatusCode::OK, Bytes::from_static(RESPONSE_DEFAULT))
3140
}
32-
Ok(Response::new(Body::from("MOSEC service")))
3341
}
3442

35-
async fn metrics(_: Request<Body>) -> Result<Response<Body>, ServiceError> {
43+
async fn metrics(_: Request<Body>) -> Response<Body> {
3644
let encoder = TextEncoder::new();
3745
let metrics = prometheus::gather();
3846
let mut buffer = vec![];
3947
encoder.encode(&metrics, &mut buffer).unwrap();
40-
Ok(Response::new(Body::from(buffer)))
48+
build_response(StatusCode::OK, Bytes::from(buffer))
4149
}
4250

43-
async fn inference(req: Request<Body>) -> Result<Response<Body>, ServiceError> {
51+
async fn inference(req: Request<Body>) -> Response<Body> {
4452
let task_manager = TaskManager::global();
4553
let data = to_bytes(req.into_body()).await.unwrap();
4654
let metrics = Metrics::global();
4755

56+
if task_manager.is_shutdown() {
57+
return build_response(
58+
StatusCode::SERVICE_UNAVAILABLE,
59+
Bytes::from_static(RESPONSE_SHUTDOWN),
60+
);
61+
}
62+
4863
if data.is_empty() {
49-
return Ok(Response::new(Body::from("No data provided")));
64+
return build_response(StatusCode::OK, Bytes::from_static(RESPONSE_EMPTY));
5065
}
5166

67+
let (status, content);
5268
metrics.remaining_task.inc();
53-
let task = task_manager.submit_task(data).await?;
54-
match task.code {
55-
TaskCode::Normal => {
56-
metrics.remaining_task.dec();
57-
metrics
58-
.duration
59-
.with_label_values(&["total", "total"])
60-
.observe(task.create_at.elapsed().as_secs_f64());
61-
metrics
62-
.throughput
63-
.with_label_values(&[StatusCode::OK.as_str()])
64-
.inc();
65-
Ok(Response::new(Body::from(task.data)))
69+
match task_manager.submit_task(data).await {
70+
Ok(task) => {
71+
content = task.data;
72+
status = match task.code {
73+
TaskCode::Normal => {
74+
// Record latency only for successful tasks
75+
metrics
76+
.duration
77+
.with_label_values(&["total", "total"])
78+
.observe(task.create_at.elapsed().as_secs_f64());
79+
StatusCode::OK
80+
}
81+
TaskCode::BadRequestError => StatusCode::BAD_REQUEST,
82+
TaskCode::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
83+
TaskCode::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
84+
}
85+
}
86+
Err(err) => {
87+
// Handle errors for which tasks cannot be retrieved
88+
content = Bytes::from(err.to_string());
89+
status = match err {
90+
ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
91+
ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
92+
ServiceError::UnknownError => StatusCode::INTERNAL_SERVER_ERROR,
93+
};
6694
}
67-
TaskCode::BadRequestError => Err(ServiceError::BadRequestError),
68-
TaskCode::ValidationError => Err(ServiceError::ValidationError),
69-
TaskCode::InternalError => Err(ServiceError::InternalError),
70-
TaskCode::UnknownError => Err(ServiceError::UnknownError),
7195
}
72-
}
73-
74-
fn error_handler(err: ServiceError) -> Response<Body> {
75-
let status = match err {
76-
ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
77-
ServiceError::BadRequestError => StatusCode::BAD_REQUEST,
78-
ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
79-
ServiceError::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
80-
ServiceError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
81-
ServiceError::GracefulShutdown => StatusCode::SERVICE_UNAVAILABLE,
82-
ServiceError::UnknownError => StatusCode::NOT_IMPLEMENTED,
83-
};
84-
let metrics = Metrics::global();
85-
8696
metrics.remaining_task.dec();
8797
metrics
8898
.throughput
8999
.with_label_values(&[status.as_str()])
90100
.inc();
91101

102+
build_response(status, content)
103+
}
104+
105+
fn build_response(status: StatusCode, content: Bytes) -> Response<Body> {
92106
Response::builder()
93107
.status(status)
94108
.header("server", HeaderValue::from_static(SERVER_INFO))
95-
.body(Body::from(err.to_string()))
109+
.body(Body::from(content))
96110
.unwrap()
97111
}
98112

99113
async fn service_func(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
100-
let res = match (req.method(), req.uri().path()) {
101-
(&Method::GET, "/") => index(req).await,
102-
(&Method::GET, "/metrics") => metrics(req).await,
103-
(&Method::POST, "/inference") => inference(req).await,
104-
_ => Ok(Response::builder()
105-
.status(StatusCode::NOT_FOUND)
106-
.body(NOT_FOUND.into())
107-
.unwrap()),
108-
};
109-
match res {
110-
Ok(mut resp) => {
111-
resp.headers_mut()
112-
.insert("server", HeaderValue::from_static(SERVER_INFO));
113-
Ok(resp)
114-
}
115-
Err(err) => Ok(error_handler(err)),
114+
match (req.method(), req.uri().path()) {
115+
(&Method::GET, "/") => Ok(index(req).await),
116+
(&Method::GET, "/metrics") => Ok(metrics(req).await),
117+
(&Method::POST, "/inference") => Ok(inference(req).await),
118+
_ => Ok(build_response(
119+
StatusCode::NOT_FOUND,
120+
Bytes::from(RESPONSE_NOT_FOUND),
121+
)),
116122
}
117123
}
118124

src/protocol.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ const LENGTH_U8_SIZE: usize = 4;
1818
const BIT_STATUS_OK: u16 = 0b1;
1919
const BIT_STATUS_BAD_REQ: u16 = 0b10;
2020
const BIT_STATUS_VALIDATION_ERR: u16 = 0b100;
21-
const BIT_STATUS_INTERNAL_ERR: u16 = 0b1000;
21+
// Others are treated as Internal Error
2222

2323
pub(crate) async fn communicate(
2424
path: PathBuf,
@@ -42,7 +42,7 @@ pub(crate) async fn communicate(
4242
Ok((mut stream, addr)) => {
4343
info!(?addr, "accepted connection from");
4444
tokio::spawn(async move {
45-
let mut code: TaskCode = TaskCode::UnknownError;
45+
let mut code: TaskCode = TaskCode::InternalError;
4646
let mut ids: Vec<u32> = Vec::with_capacity(batch_size);
4747
let mut data: Vec<Bytes> = Vec::with_capacity(batch_size);
4848
let task_manager = TaskManager::global();
@@ -134,10 +134,8 @@ async fn read_message(
134134
TaskCode::BadRequestError
135135
} else if flag & BIT_STATUS_VALIDATION_ERR > 0 {
136136
TaskCode::ValidationError
137-
} else if flag & BIT_STATUS_INTERNAL_ERR > 0 {
138-
TaskCode::InternalError
139137
} else {
140-
TaskCode::UnknownError
138+
TaskCode::InternalError
141139
};
142140

143141
let mut id_buf = [0u8; TASK_ID_U8_SIZE];
@@ -271,7 +269,7 @@ mod tests {
271269
let mut stream = UnixStream::connect(&path).await.unwrap();
272270
let mut recv_ids = Vec::new();
273271
let mut recv_data = Vec::new();
274-
let mut code = TaskCode::UnknownError;
272+
let mut code = TaskCode::InternalError;
275273
read_message(&mut stream, &mut code, &mut recv_ids, &mut recv_data)
276274
.await
277275
.expect("read message error");

0 commit comments

Comments
 (0)