Skip to content

Commit

Permalink
Merge pull request twitter#421 from brirams/expose_file_descriptor
Browse files Browse the repository at this point in the history
Expose FileDescriptor and FieldDescriptor in ThriftToDynamicProto
  • Loading branch information
rangadi committed Oct 3, 2014
2 parents 6a690fd + 12b49ab commit 72e6e4b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.protobuf.Message;

import org.apache.thrift.TBase;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TType;

import org.slf4j.Logger;
Expand Down Expand Up @@ -54,6 +55,8 @@ public class ThriftToDynamicProto<T extends TBase<?, ?>> {
private boolean supportNestedObjects = false;
private boolean ignoreUnsupportedTypes = false;

private final Descriptors.FileDescriptor fileDescriptor;

// a map of descriptors keyed to their typeName
private Map<String, DescriptorProtos.DescriptorProto.Builder> descriptorBuilderMap
= Maps.newHashMap();
Expand Down Expand Up @@ -136,14 +139,16 @@ public ThriftToDynamicProto(Class<T> thriftClass,
Descriptors.Descriptor msgDescriptor = dynamicDescriptor.findMessageTypeByName(type);
messageBuilderMap.put(type, DynamicMessage.newBuilder(msgDescriptor));
}

fileDescriptor = dynamicDescriptor;
}

/**
* For the given thriftClass, return a Protobufs builder to build a similar protobuf class.
* @param thriftClass The thrift class for which the builder is desired.
* @return a protobuf message builder
*/
public Message.Builder getBuilder(Class<? extends TBase> thriftClass) {
public Message.Builder getBuilder(Class<? extends TBase<?, ?>> thriftClass) {
return messageBuilderMap.get(protoMessageType(thriftClass)).clone();
}

Expand Down Expand Up @@ -343,37 +348,38 @@ private DescriptorProtos.FieldDescriptorProto.Builder fieldDescriptorProtoBuilde
* @param thriftObj thrift object
* @return protobuf protobuf message
*/
@SuppressWarnings("unchecked")
public Message convert(T thriftObj) {
return doConvert(Preconditions.checkNotNull(thriftObj, "Can not convert a null object"));
return doConvert((TBase<?, ?>)
Preconditions.checkNotNull(thriftObj, "Can not convert a null object"));
}

/**
* conver TBase object to Message object
* @param thriftObj
*/
@SuppressWarnings("unchecked")
public Message doConvert(TBase thriftObj) {
public <F extends TFieldIdEnum> Message doConvert(TBase<?, F> thriftObj) {
if (thriftObj == null) { return null; }

Preconditions.checkState(hasBuilder(thriftObj.getClass()),
"No message builder found for thrift class: " + thriftObj.getClass().getCanonicalName());
Class<TBase<?, F>> clazz = (Class<TBase<?, F>>) thriftObj.getClass();
checkState(clazz);

Message.Builder builder = getBuilder(thriftObj.getClass());
Message.Builder builder = getBuilder(clazz);

TStructDescriptor fieldDesc = TStructDescriptor.getInstance(
(Class<? extends TBase<?, ?>>) thriftObj.getClass());
TStructDescriptor fieldDesc = TStructDescriptor.getInstance(clazz);
int fieldId = 0;
for (Field tField : fieldDesc.getFields()) {
// don't want to carry over default values from unset fields
if (!thriftObj.isSet(tField.getFieldIdEnum())
if (!thriftObj.isSet((F) tField.getFieldIdEnum())
|| (!supportNestedObjects && hasNestedObject(tField))) {
fieldId++;
continue;
}

// recurse into the object if it's a struct, otherwise just add the field
if (supportNestedObjects && tField.getType() == TType.STRUCT) {
TBase fieldValue = (TBase) fieldDesc.getFieldValue(fieldId++, thriftObj);
TBase<?, ?> fieldValue = (TBase<?, ?>) fieldDesc.getFieldValue(fieldId++, thriftObj);
Message message = doConvert(fieldValue);
if (message != null) {
FieldDescriptor protoFieldDesc = builder.getDescriptorForType().findFieldByName(
Expand All @@ -387,7 +393,12 @@ public Message doConvert(TBase thriftObj) {
return builder.build();
}

private boolean hasBuilder(Class<? extends TBase> thriftClass) {
private void checkState(Class<? extends TBase<?, ?>> thriftClass) {
Preconditions.checkState(hasBuilder(thriftClass),
"No message builder found for thrift class: " + thriftClass.getCanonicalName());
}

private boolean hasBuilder(Class<? extends TBase<?, ?>> thriftClass) {
return messageBuilderMap.get(protoMessageType(thriftClass)) != null;
}

Expand Down Expand Up @@ -418,7 +429,8 @@ private boolean isStructContainer(Field tField) {
|| (tField.isSet() && tField.getSetElemField().isStruct());
}

private int convertField(TBase thriftObj, Message.Builder builder,
@SuppressWarnings("unchecked")
private int convertField(TBase<?, ?> thriftObj, Message.Builder builder,
TStructDescriptor fieldDesc, int fieldId, Field tField) {
int tmpFieldId = fieldId;
FieldDescriptor protoFieldDesc = builder.getDescriptorForType().findFieldByName(
Expand Down Expand Up @@ -446,8 +458,8 @@ private int convertField(TBase thriftObj, Message.Builder builder,
if (isStructContainer(tField)) {
List<Message> convertedStructs = Lists.newLinkedList();

Iterable<TBase> structIterator = (Iterable<TBase>) fieldValue;
for (TBase struct : structIterator) {
Iterable<TBase<?, ?>> structIterator = (Iterable<TBase<?, ?>>) fieldValue;
for (TBase<?, ?> struct : structIterator) {
convertedStructs.add(doConvert(struct));
}

Expand Down Expand Up @@ -504,14 +516,14 @@ private Message buildMapEntryMessage(Message.Builder mapBuilder, Field field,

Object convertedKey;
if (isKeyStruct) {
convertedKey = doConvert((TBase) mapKey);
convertedKey = doConvert((TBase<?, ?>) mapKey);
} else {
convertedKey = sanitizeRawValue(mapKey, field.getMapKeyField());
}

Object convertedValue;
if (isValueStruct) {
convertedValue = doConvert((TBase) mapValue);
convertedValue = doConvert((TBase<?, ?>) mapValue);
} else {
convertedValue = sanitizeRawValue(mapValue, field.getMapValueField());
}
Expand Down Expand Up @@ -570,7 +582,7 @@ private Type thriftTypeToProtoType(Field tField) {

// name the proto message type after the thrift class name. Dots are not permitted in protobuf
// names
private String protoMessageType(Class<? extends TBase> thriftClass) {
private String protoMessageType(Class<? extends TBase<?, ?>> thriftClass) {
return thriftClass.getCanonicalName().replace(".", "_");
}

Expand All @@ -581,4 +593,18 @@ private String protoMessageType(Class<? extends TBase> thriftClass) {
private String mapProtoMessageType(TStructDescriptor descriptor, Field field) {
return String.format("%s_%s", protoMessageType(descriptor.getThriftClass()), field.getName());
}

// Given the class name, finds the corresponding Descriptor and return the appropriate
// FieldDescriptor
public FieldDescriptor getFieldDescriptor(Class<? extends TBase<?, ?>> thriftClass,
String fieldName) {
checkState(thriftClass);
Descriptors.Descriptor descriptor = getBuilder(thriftClass).getDescriptorForType();
return descriptor.findFieldByName(fieldName);
}

// Picks off the FileDescriptor for this instance
public Descriptors.FileDescriptor getFileDescriptor() {
return fileDescriptor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Message;

import com.twitter.elephantbird.thrift.test.AddressBook;
Expand All @@ -21,13 +23,19 @@
import com.twitter.elephantbird.thrift.test.PrimitiveSetsStruct;
import com.twitter.elephantbird.util.Protobufs;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

public class TestThriftToDynamicProto {

@Rule
public ExpectedException exception = ExpectedException.none();

private PhoneNumber genPhoneNumber(String number, PhoneType type) {
PhoneNumber phoneNumber = new PhoneNumber(number);
phoneNumber.setType(type);
Expand Down Expand Up @@ -234,4 +242,36 @@ public void testMapConversionWhenNestedStructsDisabled() throws DescriptorValida
Message msg = thriftToProto.convert(mapStruct);
assertTrue(!Protobufs.hasFieldByName(msg, "entries"));
}

@Test
public void testBadThriftTypeForGetFieldDescriptor() throws DescriptorValidationException {
ThriftToDynamicProto<PhoneNumber> converter = new ThriftToDynamicProto<PhoneNumber>(PhoneNumber.class);

exception.expect(IllegalStateException.class);
converter.getFieldDescriptor(Person.class, "some_field");
}

@Test
public void testGetFieldTypeDescriptor() throws DescriptorValidationException {
ThriftToDynamicProto<Person> converter = new ThriftToDynamicProto<Person>(Person.class);
Person person = genPerson();
Message msg = converter.convert(person);

FieldDescriptor expectedFd = msg.getDescriptorForType().findFieldByName("email");
FieldDescriptor actualFd = converter.getFieldDescriptor(Person.class, "email");

assertEquals(expectedFd, actualFd);
}

@Test
public void testGetFileDescriptor() throws DescriptorValidationException {
ThriftToDynamicProto<Person> converter = new ThriftToDynamicProto<Person>(Person.class);
Person person = genPerson();
Message msg = converter.convert(person);

FileDescriptor expectedFd = msg.getDescriptorForType().getFile();
FileDescriptor actualFd = converter.getFileDescriptor();

assertEquals(expectedFd, actualFd);
}
}

0 comments on commit 72e6e4b

Please sign in to comment.