Skip to content

Commit 258fe68

Browse files
committed
feat(derive): add #[postgres(allow_mismatch)]
1 parent 790af54 commit 258fe68

File tree

8 files changed

+250
-27
lines changed

8 files changed

+250
-27
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use postgres_types::{FromSql, ToSql};
2+
3+
#[derive(ToSql, Debug)]
4+
#[postgres(allow_mismatch)]
5+
struct ToSqlAllowMismatchStruct {
6+
a: i32,
7+
}
8+
9+
#[derive(FromSql, Debug)]
10+
#[postgres(allow_mismatch)]
11+
struct FromSqlAllowMismatchStruct {
12+
a: i32,
13+
}
14+
15+
#[derive(ToSql, Debug)]
16+
#[postgres(allow_mismatch)]
17+
struct ToSqlAllowMismatchTupleStruct(i32, i32);
18+
19+
#[derive(FromSql, Debug)]
20+
#[postgres(allow_mismatch)]
21+
struct FromSqlAllowMismatchTupleStruct(i32, i32);
22+
23+
#[derive(FromSql, Debug)]
24+
#[postgres(transparent, allow_mismatch)]
25+
struct TransparentFromSqlAllowMismatchStruct(i32);
26+
27+
#[derive(FromSql, Debug)]
28+
#[postgres(allow_mismatch, transparent)]
29+
struct AllowMismatchFromSqlTransparentStruct(i32);
30+
31+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
error: #[postgres(allow_mismatch)] may only be applied to enums
2+
--> src/compile-fail/invalid-allow-mismatch.rs:4:1
3+
|
4+
4 | / #[postgres(allow_mismatch)]
5+
5 | | struct ToSqlAllowMismatchStruct {
6+
6 | | a: i32,
7+
7 | | }
8+
| |_^
9+
10+
error: #[postgres(allow_mismatch)] may only be applied to enums
11+
--> src/compile-fail/invalid-allow-mismatch.rs:10:1
12+
|
13+
10 | / #[postgres(allow_mismatch)]
14+
11 | | struct FromSqlAllowMismatchStruct {
15+
12 | | a: i32,
16+
13 | | }
17+
| |_^
18+
19+
error: #[postgres(allow_mismatch)] may only be applied to enums
20+
--> src/compile-fail/invalid-allow-mismatch.rs:16:1
21+
|
22+
16 | / #[postgres(allow_mismatch)]
23+
17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32);
24+
| |_______________________________________________^
25+
26+
error: #[postgres(allow_mismatch)] may only be applied to enums
27+
--> src/compile-fail/invalid-allow-mismatch.rs:20:1
28+
|
29+
20 | / #[postgres(allow_mismatch)]
30+
21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32);
31+
| |_________________________________________________^
32+
33+
error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]
34+
--> src/compile-fail/invalid-allow-mismatch.rs:24:25
35+
|
36+
24 | #[postgres(transparent, allow_mismatch)]
37+
| ^^^^^^^^^^^^^^
38+
39+
error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]
40+
--> src/compile-fail/invalid-allow-mismatch.rs:28:28
41+
|
42+
28 | #[postgres(allow_mismatch, transparent)]
43+
| ^^^^^^^^^^^

postgres-derive-test/src/enums.rs

+71-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::test_type;
2-
use postgres::{Client, NoTls};
2+
use postgres::{error::DbError, Client, NoTls};
33
use postgres_types::{FromSql, ToSql, WrongType};
44
use std::error::Error;
55

@@ -131,3 +131,73 @@ fn missing_variant() {
131131
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
132132
assert!(err.source().unwrap().is::<WrongType>());
133133
}
134+
135+
#[test]
136+
fn allow_mismatch_enums() {
137+
#[derive(Debug, ToSql, FromSql, PartialEq)]
138+
#[postgres(allow_mismatch)]
139+
enum Foo {
140+
Bar,
141+
}
142+
143+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
144+
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
145+
.unwrap();
146+
147+
let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap();
148+
assert_eq!(row.get::<_, Foo>(0), Foo::Bar);
149+
}
150+
151+
#[test]
152+
fn missing_enum_variant() {
153+
#[derive(Debug, ToSql, FromSql, PartialEq)]
154+
#[postgres(allow_mismatch)]
155+
enum Foo {
156+
Bar,
157+
Buz,
158+
}
159+
160+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
161+
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
162+
.unwrap();
163+
164+
let err = conn
165+
.query_one("SELECT $1::\"Foo\"", &[&Foo::Buz])
166+
.unwrap_err();
167+
assert!(err.source().unwrap().is::<DbError>());
168+
}
169+
170+
#[test]
171+
fn allow_mismatch_and_renaming() {
172+
#[derive(Debug, ToSql, FromSql, PartialEq)]
173+
#[postgres(name = "foo", allow_mismatch)]
174+
enum Foo {
175+
#[postgres(name = "bar")]
176+
Bar,
177+
#[postgres(name = "buz")]
178+
Buz,
179+
}
180+
181+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
182+
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[])
183+
.unwrap();
184+
185+
let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap();
186+
assert_eq!(row.get::<_, Foo>(0), Foo::Buz);
187+
}
188+
189+
#[test]
190+
fn wrong_name_and_allow_mismatch() {
191+
#[derive(Debug, ToSql, FromSql, PartialEq)]
192+
#[postgres(allow_mismatch)]
193+
enum Foo {
194+
Bar,
195+
}
196+
197+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
198+
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
199+
.unwrap();
200+
201+
let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
202+
assert!(err.source().unwrap().is::<WrongType>());
203+
}

postgres-derive/src/accepts.rs

+24-18
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
3131
}
3232
}
3333

34-
pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
34+
pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream {
3535
let num_variants = variants.len();
3636
let variant_names = variants.iter().map(|v| &v.name);
3737

38-
quote! {
39-
if type_.name() != #name {
40-
return false;
38+
if allow_mismatch {
39+
quote! {
40+
type_.name() == #name
4141
}
42+
} else {
43+
quote! {
44+
if type_.name() != #name {
45+
return false;
46+
}
4247

43-
match *type_.kind() {
44-
::postgres_types::Kind::Enum(ref variants) => {
45-
if variants.len() != #num_variants {
46-
return false;
47-
}
48-
49-
variants.iter().all(|v| {
50-
match &**v {
51-
#(
52-
#variant_names => true,
53-
)*
54-
_ => false,
48+
match *type_.kind() {
49+
::postgres_types::Kind::Enum(ref variants) => {
50+
if variants.len() != #num_variants {
51+
return false;
5552
}
56-
})
53+
54+
variants.iter().all(|v| {
55+
match &**v {
56+
#(
57+
#variant_names => true,
58+
)*
59+
_ => false,
60+
}
61+
})
62+
}
63+
_ => false,
5764
}
58-
_ => false,
5965
}
6066
}
6167
}

postgres-derive/src/fromsql.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
4848
))
4949
}
5050
}
51+
} else if overrides.allow_mismatch {
52+
match input.data {
53+
Data::Enum(ref data) => {
54+
let variants = data
55+
.variants
56+
.iter()
57+
.map(|variant| Variant::parse(variant, overrides.rename_all))
58+
.collect::<Result<Vec<_>, _>>()?;
59+
(
60+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
61+
enum_body(&input.ident, &variants),
62+
)
63+
}
64+
_ => {
65+
return Err(Error::new_spanned(
66+
input,
67+
"#[postgres(allow_mismatch)] may only be applied to enums",
68+
));
69+
}
70+
}
5171
} else {
5272
match input.data {
5373
Data::Enum(ref data) => {
@@ -57,7 +77,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
5777
.map(|variant| Variant::parse(variant, overrides.rename_all))
5878
.collect::<Result<Vec<_>, _>>()?;
5979
(
60-
accepts::enum_body(&name, &variants),
80+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
6181
enum_body(&input.ident, &variants),
6282
)
6383
}

postgres-derive/src/overrides.rs

+19-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub struct Overrides {
77
pub name: Option<String>,
88
pub rename_all: Option<RenameRule>,
99
pub transparent: bool,
10+
pub allow_mismatch: bool,
1011
}
1112

1213
impl Overrides {
@@ -15,6 +16,7 @@ impl Overrides {
1516
name: None,
1617
rename_all: None,
1718
transparent: false,
19+
allow_mismatch: false,
1820
};
1921

2022
for attr in attrs {
@@ -74,11 +76,25 @@ impl Overrides {
7476
}
7577
}
7678
Meta::Path(path) => {
77-
if !path.is_ident("transparent") {
79+
if path.is_ident("transparent") {
80+
if overrides.allow_mismatch {
81+
return Err(Error::new_spanned(
82+
path,
83+
"#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]",
84+
));
85+
}
86+
overrides.transparent = true;
87+
} else if path.is_ident("allow_mismatch") {
88+
if overrides.transparent {
89+
return Err(Error::new_spanned(
90+
path,
91+
"#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]",
92+
));
93+
}
94+
overrides.allow_mismatch = true;
95+
} else {
7896
return Err(Error::new_spanned(path, "unknown override"));
7997
}
80-
81-
overrides.transparent = true;
8298
}
8399
bad => return Err(Error::new_spanned(bad, "unknown attribute")),
84100
}

postgres-derive/src/tosql.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
4444
));
4545
}
4646
}
47+
} else if overrides.allow_mismatch {
48+
match input.data {
49+
Data::Enum(ref data) => {
50+
let variants = data
51+
.variants
52+
.iter()
53+
.map(|variant| Variant::parse(variant, overrides.rename_all))
54+
.collect::<Result<Vec<_>, _>>()?;
55+
(
56+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
57+
enum_body(&input.ident, &variants),
58+
)
59+
}
60+
_ => {
61+
return Err(Error::new_spanned(
62+
input,
63+
"#[postgres(allow_mismatch)] may only be applied to enums",
64+
));
65+
}
66+
}
4767
} else {
4868
match input.data {
4969
Data::Enum(ref data) => {
@@ -53,7 +73,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
5373
.map(|variant| Variant::parse(variant, overrides.rename_all))
5474
.collect::<Result<Vec<_>, _>>()?;
5575
(
56-
accepts::enum_body(&name, &variants),
76+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
5777
enum_body(&input.ident, &variants),
5878
)
5979
}

postgres-types/src/lib.rs

+20-3
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@
138138
//! #[derive(Debug, ToSql, FromSql)]
139139
//! #[postgres(name = "mood", rename_all = "snake_case")]
140140
//! enum Mood {
141-
//! VerySad, // very_sad
142141
//! #[postgres(name = "ok")]
143142
//! Ok, // ok
144143
//! VeryHappy, // very_happy
@@ -155,10 +154,28 @@
155154
//! - `"kebab-case"`
156155
//! - `"SCREAMING-KEBAB-CASE"`
157156
//! - `"Train-Case"`
158-
157+
//!
158+
//! ## Allowing Enum Mismatches
159+
//!
160+
//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum
161+
//! variants between the Rust and Postgres types.
162+
//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition:
163+
//!
164+
//! ```sql
165+
//! CREATE TYPE mood AS ENUM (
166+
//! 'Sad',
167+
//! 'Ok',
168+
//! 'Happy'
169+
//! );
170+
//! ```
171+
//! #[postgres(allow_mismatch)]
172+
//! enum Mood {
173+
//! Happy,
174+
//! Meh,
175+
//! }
176+
//! ```
159177
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
160178
#![warn(clippy::all, rust_2018_idioms, missing_docs)]
161-
162179
use fallible_iterator::FallibleIterator;
163180
use postgres_protocol::types::{self, ArrayDimension};
164181
use std::any::type_name;

0 commit comments

Comments
 (0)