diff --git a/util/EIP-712/src/eip712.rs b/util/EIP-712/src/eip712.rs index 426303403..836e1dbb9 100644 --- a/util/EIP-712/src/eip712.rs +++ b/util/EIP-712/src/eip712.rs @@ -19,8 +19,7 @@ use serde_json::{Value}; use std::collections::HashMap; use ethereum_types::{U256, H256, Address}; use regex::Regex; -use validator::Validate; -use validator::ValidationErrors; +use validator::{Validate, ValidationError, ValidationErrors}; use lazy_static::lazy_static; pub(crate) type MessageTypes = HashMap>; @@ -32,16 +31,28 @@ lazy_static! { } #[serde(rename_all = "camelCase")] -#[serde(deny_unknown_fields)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)] +#[validate(schema(function = "validate_domain"))] pub(crate) struct EIP712Domain { - pub(crate) name: String, - pub(crate) version: String, - pub(crate) chain_id: U256, - pub(crate) verifying_contract: Address, + #[serde(skip_serializing_if="Option::is_none")] + pub(crate) name: Option, + #[serde(skip_serializing_if="Option::is_none")] + pub(crate) version: Option, + #[serde(skip_serializing_if="Option::is_none")] + pub(crate) chain_id: Option, + #[serde(skip_serializing_if="Option::is_none")] + pub(crate) verifying_contract: Option
, #[serde(skip_serializing_if="Option::is_none")] pub(crate) salt: Option, } + +fn validate_domain(domain: &EIP712Domain) -> Result<(), ValidationError> { + match (domain.name.as_ref(), domain.version.as_ref(), domain.chain_id, domain.verifying_contract, domain.salt) { + (None, None, None, None, None) => Err(ValidationError::new("EIP712Domain must include at least one field")), + _ => Ok(()) + } +} + /// EIP-712 struct #[serde(rename_all = "camelCase")] #[serde(deny_unknown_fields)] @@ -55,6 +66,7 @@ pub struct EIP712 { impl Validate for EIP712 { fn validate(&self) -> Result<(), ValidationErrors> { + self.domain.validate()?; for field_types in self.types.values() { for field_type in field_types { field_type.validate()?; @@ -159,7 +171,8 @@ mod tests { { "name": "name", "type": "string" }, { "name": "version", "type": "string" }, { "name": "chainId", "type": "7uint256[x] Seun" }, - { "name": "verifyingContract", "type": "address" } + { "name": "verifyingContract", "type": "address" }, + { "name": "salt", "type": "bytes32" } ], "Person": [ { "name": "name", "type": "string" }, @@ -175,4 +188,59 @@ mod tests { let data = from_str::(string).unwrap(); assert_eq!(data.validate().is_err(), true); } + + #[test] + fn test_valid_domain() { + let string = r#"{ + "primaryType": "Test", + "domain": { + "name": "Ether Mail", + "version": "1", + "chainId": "0x1", + "verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC", + "salt": "0x0000000000000000000000000000000000000000000000000000000000000001" + }, + "message": { + "test": "It works!" + }, + "types": { + "EIP712Domain": [ + { "name": "name", "type": "string" }, + { "name": "version", "type": "string" }, + { "name": "chainId", "type": "uint256" }, + { "name": "verifyingContract", "type": "address" }, + { "name": "salt", "type": "bytes32" } + ], + "Test": [ + { "name": "test", "type": "string" } + ] + } + }"#; + let data = from_str::(string).unwrap(); + assert_eq!(data.validate().is_err(), false); + } + + #[test] + fn domain_needs_at_least_one_field() { + let string = r#"{ + "primaryType": "Test", + "domain": {}, + "message": { + "test": "It works!" + }, + "types": { + "EIP712Domain": [ + { "name": "name", "type": "string" }, + { "name": "version", "type": "string" }, + { "name": "chainId", "type": "uint256" }, + { "name": "verifyingContract", "type": "address" } + ], + "Test": [ + { "name": "test", "type": "string" } + ] + } + }"#; + let data = from_str::(string).unwrap(); + assert_eq!(data.validate().is_err(), true); + } }