codegen expansion for traits

This commit is contained in:
NikVolf 2016-07-12 17:45:35 +02:00
parent 2310ecb480
commit f380340a9b
5 changed files with 70 additions and 34 deletions

View File

@ -28,6 +28,7 @@ use syntax::ast::{
TraitRef, TraitRef,
Ident, Ident,
Generics, Generics,
TraitItemKind,
}; };
use syntax::ast; use syntax::ast;
@ -80,19 +81,23 @@ pub fn replace_slice_u8(builder: &aster::AstBuilder, ty: &P<ast::Ty>) -> P<ast::
ty.clone() ty.clone()
} }
struct NamedSignature<'a> {
sig: &'a MethodSig,
ident: &'a Ident,
}
fn push_invoke_signature_aster( fn push_invoke_signature_aster(
builder: &aster::AstBuilder, builder: &aster::AstBuilder,
implement: &ImplItem, named_signature: &NamedSignature,
signature: &MethodSig,
push: &mut FnMut(Annotatable), push: &mut FnMut(Annotatable),
) -> Dispatch { ) -> Dispatch {
let inputs = &signature.decl.inputs; let inputs = &named_signature.sig.decl.inputs;
let (input_type_name, input_arg_names, input_arg_tys) = if inputs.len() > 0 { let (input_type_name, input_arg_names, input_arg_tys) = if inputs.len() > 0 {
let first_field_name = field_name(builder, &inputs[0]).name.as_str(); let first_field_name = field_name(builder, &inputs[0]).name.as_str();
if first_field_name == "self" && inputs.len() == 1 { (None, vec![], vec![]) } if first_field_name == "self" && inputs.len() == 1 { (None, vec![], vec![]) }
else { else {
let skip = if first_field_name == "self" { 2 } else { 1 }; let skip = if first_field_name == "self" { 2 } else { 1 };
let name_str = format!("{}_input", implement.ident.name.as_str()); let name_str = format!("{}_input", named_signature.ident.name.as_str());
let mut arg_names = Vec::new(); let mut arg_names = Vec::new();
let mut arg_tys = Vec::new(); let mut arg_tys = Vec::new();
@ -126,9 +131,9 @@ fn push_invoke_signature_aster(
(None, vec![], vec![]) (None, vec![], vec![])
}; };
let return_type_ty = match signature.decl.output { let return_type_ty = match named_signature.sig.decl.output {
FunctionRetTy::Ty(ref ty) => { FunctionRetTy::Ty(ref ty) => {
let name_str = format!("{}_output", implement.ident.name.as_str()); let name_str = format!("{}_output", named_signature.ident.name.as_str());
let tree = builder.item() let tree = builder.item()
.attr().word("derive(Binary)") .attr().word("derive(Binary)")
.attr().word("allow(non_camel_case_types)") .attr().word("allow(non_camel_case_types)")
@ -141,7 +146,7 @@ fn push_invoke_signature_aster(
}; };
Dispatch { Dispatch {
function_name: format!("{}", implement.ident.name.as_str()), function_name: format!("{}", named_signature.ident.name.as_str()),
input_type_name: input_type_name, input_type_name: input_type_name,
input_arg_names: input_arg_names, input_arg_names: input_arg_names,
input_arg_tys: input_arg_tys, input_arg_tys: input_arg_tys,
@ -587,8 +592,8 @@ fn push_client_implementation(
let handshake_item = quote_impl_item!(cx, let handshake_item = quote_impl_item!(cx,
pub fn handshake(&self) -> Result<(), ::ipc::Error> { pub fn handshake(&self) -> Result<(), ::ipc::Error> {
let payload = ::ipc::Handshake { let payload = ::ipc::Handshake {
protocol_version: Arc::<$endpoint>::protocol_version(), protocol_version: ::std::sync::Arc::<$endpoint>::protocol_version(),
api_version: Arc::<$endpoint>::api_version(), api_version: ::std::sync::Arc::<$endpoint>::api_version(),
}; };
::ipc::invoke( ::ipc::invoke(
@ -771,8 +776,43 @@ fn implement_interface(
item: &Item, item: &Item,
push: &mut FnMut(Annotatable), push: &mut FnMut(Annotatable),
) -> Result<InterfaceMap, Error> { ) -> Result<InterfaceMap, Error> {
let (generics, impl_trait, original_ty, impl_items) = match item.node { let (generics, impl_trait, original_ty, dispatch_table) = match item.node {
ast::ItemKind::Impl(_, _, ref generics, ref impl_trait, ref ty, ref impl_items) => (generics, impl_trait, ty, impl_items), ast::ItemKind::Impl(_, _, ref generics, ref impl_trait, ref ty, ref impl_items) => {
let mut method_signatures = Vec::new();
for impl_item in impl_items {
if let ImplItemKind::Method(ref signature, _) = impl_item.node {
method_signatures.push(NamedSignature { ident: &impl_item.ident, sig: signature });
}
}
let dispatch_table = method_signatures.iter().map(|named_signature|
push_invoke_signature_aster(builder, named_signature, push))
.collect::<Vec<Dispatch>>();
(generics, impl_trait.clone(), ty.clone(), dispatch_table)
},
ast::ItemKind::Trait(_, ref generics, _, ref trait_items) => {
let mut method_signatures = Vec::new();
for trait_item in trait_items {
if let TraitItemKind::Method(ref signature, _) = trait_item.node {
method_signatures.push(NamedSignature { ident: &trait_item.ident, sig: signature });
}
}
let dispatch_table = method_signatures.iter().map(|named_signature|
push_invoke_signature_aster(builder, named_signature, push))
.collect::<Vec<Dispatch>>();
(
generics,
Some(ast::TraitRef {
path: builder.path().ids(&[item.ident.name]).build(),
ref_id: item.id,
}),
builder.ty().id(item.ident),
dispatch_table
)
},
_ => { _ => {
cx.span_err( cx.span_err(
item.span, item.span,
@ -783,30 +823,19 @@ fn implement_interface(
let impl_generics = builder.from_generics(generics.clone()).build(); let impl_generics = builder.from_generics(generics.clone()).build();
let where_clause = &impl_generics.where_clause; let where_clause = &impl_generics.where_clause;
let mut method_signatures = Vec::new();
for impl_item in impl_items {
if let ImplItemKind::Method(ref signature, _) = impl_item.node {
method_signatures.push((impl_item, signature))
}
}
let dispatch_table = method_signatures.iter().map(|&(impl_item, signature)|
push_invoke_signature_aster(builder, impl_item, signature, push))
.collect::<Vec<Dispatch>>();
let dispatch_arms = implement_dispatch_arms(cx, builder, &dispatch_table, false); let dispatch_arms = implement_dispatch_arms(cx, builder, &dispatch_table, false);
let dispatch_arms_buffered = implement_dispatch_arms(cx, builder, &dispatch_table, true); let dispatch_arms_buffered = implement_dispatch_arms(cx, builder, &dispatch_table, true);
let (handshake_arm, handshake_arm_buf) = implement_handshake_arm(cx); let (handshake_arm, handshake_arm_buf) = implement_handshake_arm(cx);
let ty = ty_ident_map(&original_ty).ident(builder); let ty = ty_ident_map(&original_ty).ident(builder);
let interface_endpoint = match *impl_trait { let (interface_endpoint, host_generics) = match impl_trait {
Some(ref trait_) => builder.id(::syntax::print::pprust::path_to_string(&trait_.path)), Some(ref trait_) => (builder.id(::syntax::print::pprust::path_to_string(&trait_.path)), None),
None => ty None => (ty, Some(&impl_generics)),
}; };
let ipc_item = quote_item!(cx, let ipc_item = quote_item!(cx,
impl $impl_generics ::ipc::IpcInterface<$interface_endpoint> for Arc<$interface_endpoint> $where_clause { impl $host_generics ::ipc::IpcInterface<$interface_endpoint> for ::std::sync::Arc<$interface_endpoint> $where_clause {
fn dispatch<R>(&self, r: &mut R) -> Vec<u8> fn dispatch<R>(&self, r: &mut R) -> Vec<u8>
where R: ::std::io::Read where R: ::std::io::Read
{ {

View File

@ -19,10 +19,11 @@ mod tests {
use super::super::service::*; use super::super::service::*;
use super::super::binary::*; use super::super::binary::*;
use super::super::nested::{DBClient,DBWriter}; use super::super::nested::{DBClient, DBWriter};
use ipc::*; use ipc::*;
use devtools::*; use devtools::*;
use semver::Version; use semver::Version;
use std::sync::Arc;
#[test] #[test]
fn call_service() { fn call_service() {
@ -33,7 +34,7 @@ mod tests {
4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
10, 0, 0, 0]); 10, 0, 0, 0]);
let service = Service::new(); let service = Arc::new(Service::new());
assert_eq!(0, *service.commits.read().unwrap()); assert_eq!(0, *service.commits.read().unwrap());
service.dispatch(&mut socket); service.dispatch(&mut socket);
@ -65,7 +66,7 @@ mod tests {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
]); ]);
let service = Service::new(); let service = Arc::new(Service::new());
let result = service.dispatch(&mut socket); let result = service.dispatch(&mut socket);
// single `true` // single `true`
@ -109,9 +110,9 @@ mod tests {
#[test] #[test]
fn query_default_version() { fn query_default_version() {
let ver = Service::protocol_version(); let ver = Arc::<Service>::protocol_version();
assert_eq!(ver, Version::parse("1.0.0").unwrap()); assert_eq!(ver, Version::parse("1.0.0").unwrap());
let ver = Service::api_version(); let ver = Arc::<Service>::api_version();
assert_eq!(ver, Version::parse("1.0.0").unwrap()); assert_eq!(ver, Version::parse("1.0.0").unwrap());
} }

View File

@ -33,7 +33,7 @@ pub trait DBWriter {
fn write_slice(&self, data: &[u8]) -> Result<(), DBError>; fn write_slice(&self, data: &[u8]) -> Result<(), DBError>;
} }
impl<L: Sized> IpcConfig for DB<L> {} impl IpcConfig<DBWriter> for ::std::sync::Arc<DBWriter> {}
#[derive(Binary)] #[derive(Binary)]
pub enum DBError { Write, Read } pub enum DBError { Write, Read }
@ -53,3 +53,9 @@ impl<L: Sized> DBWriter for DB<L> {
} }
} }
#[derive(Ipc)]
trait DBNotify {
fn notify(&self, a: u64, b: u64) -> bool;
}
impl IpcConfig<DBNotify> for ::std::sync::Arc<DBNotify> { }

View File

@ -70,4 +70,4 @@ impl Service {
} }
} }
impl ::ipc::IpcConfig for Service {} impl ::ipc::IpcConfig<Service> for ::std::sync::Arc<Service> {}

View File

@ -31,4 +31,4 @@ impl BadlyNamedService {
} }
} }
impl ::ipc::IpcConfig for BadlyNamedService {} impl ::ipc::IpcConfig<BadlyNamedService> for ::std::sync::Arc<BadlyNamedService> {}