Commit 75bcccd1 by Jeff Hagelberg

ATLAS-1386: Avoid uunnecessary type cache lookups

parent b6c408b3
...@@ -52,7 +52,7 @@ public enum AtlasErrorCode { ...@@ -52,7 +52,7 @@ public enum AtlasErrorCode {
PATCH_NOT_APPLICABLE_FOR_TYPE(400, "ATLAS40022E", "{0} - invalid patch for type {1}"), PATCH_NOT_APPLICABLE_FOR_TYPE(400, "ATLAS40022E", "{0} - invalid patch for type {1}"),
PATCH_FOR_UNKNOWN_TYPE(400, "ATLAS40023E", "{0} - patch references unknown type {1}"), PATCH_FOR_UNKNOWN_TYPE(400, "ATLAS40023E", "{0} - patch references unknown type {1}"),
PATCH_INVALID_DATA(400, "ATLAS40024E", "{0} - patch data is invalid for type {1}"), PATCH_INVALID_DATA(400, "ATLAS40024E", "{0} - patch data is invalid for type {1}"),
TYPE_NAME_INVALID_FORMAT(400, "ATLAS40025E", "{0}: invalid name for {1}. Only alphanumeric and '_' are allowed."), TYPE_NAME_INVALID_FORMAT(400, "ATLAS40025E", "{0}: invalid name for {1}. Names must consist of a letter followed by a sequence of letter, number, or '_' characters"),
// All Not found enums go here // All Not found enums go here
TYPE_NAME_NOT_FOUND(404, "ATLAS4041E", "Given typename {0} was invalid"), TYPE_NAME_NOT_FOUND(404, "ATLAS4041E", "Given typename {0} was invalid"),
......
...@@ -61,8 +61,8 @@ public class AtlasTypeUtil { ...@@ -61,8 +61,8 @@ public class AtlasTypeUtil {
private static final Pattern NAME_PATTERN = Pattern.compile(NAME_REGEX); private static final Pattern NAME_PATTERN = Pattern.compile(NAME_REGEX);
private static final Pattern TRAIT_NAME_PATTERN = Pattern.compile(TRAIT_NAME_REGEX); private static final Pattern TRAIT_NAME_PATTERN = Pattern.compile(TRAIT_NAME_REGEX);
private static final String InvalidTypeNameErrorMessage = "Only alphanumeric characters, numbers and '_' are allowed in names."; private static final String InvalidTypeNameErrorMessage = "Names must consist of a letter followed by a sequence of letter, number, or '_' characters.";
private static final String InvalidTraitTypeNameErrorMessage = "Only alphanumeric characters, numbers, '.' and '_' are allowed in names."; private static final String InvalidTraitTypeNameErrorMessage = "Names must consist of a leter followed by a sequence of letters, numbers, '.', or '_' characters.";
static { static {
Collections.addAll(ATLAS_BUILTIN_TYPENAMES, AtlasBaseTypeDef.ATLAS_BUILTIN_TYPES); Collections.addAll(ATLAS_BUILTIN_TYPENAMES, AtlasBaseTypeDef.ATLAS_BUILTIN_TYPES);
......
...@@ -9,6 +9,7 @@ ATLAS-1060 Add composite indexes for exact match performance improvements for al ...@@ -9,6 +9,7 @@ ATLAS-1060 Add composite indexes for exact match performance improvements for al
ATLAS-1127 Modify creation and modification timestamps to Date instead of Long(sumasai) ATLAS-1127 Modify creation and modification timestamps to Date instead of Long(sumasai)
ALL CHANGES: ALL CHANGES:
ATLAS-1386 Avoid uunnecessary type cache lookups (jnhagelb)
ATLAS-1464 option to include only specified attributes in notification message (sarath.kum4r@gmail.com via mneethiraj) ATLAS-1464 option to include only specified attributes in notification message (sarath.kum4r@gmail.com via mneethiraj)
ATLAS-1460 v2 search API updated to return name/description/owner and classification names in result (vimalsharma via mneethiraj) ATLAS-1460 v2 search API updated to return name/description/owner and classification names in result (vimalsharma via mneethiraj)
ATLAS-1434 fixed unit test to use correct type names; updated error message per review comments (ashutoshm via mneethiraj) ATLAS-1434 fixed unit test to use correct type names; updated error message per review comments (ashutoshm via mneethiraj)
......
...@@ -73,8 +73,8 @@ trait ClosureQuery { ...@@ -73,8 +73,8 @@ trait ClosureQuery {
sealed trait PathAttribute { sealed trait PathAttribute {
def toExpr : Expression = this match { def toExpr : Expression = this match {
case r : Relation => id(r.attributeName) case r : Relation => fieldId(r.attributeName)
case rr : ReverseRelation => id(s"${rr.typeName}->${rr.attributeName}") case rr : ReverseRelation => fieldId(s"${rr.typeName}->${rr.attributeName}")
} }
def toFieldName : String = this match { def toFieldName : String = this match {
...@@ -124,9 +124,9 @@ trait ClosureQuery { ...@@ -124,9 +124,9 @@ trait ClosureQuery {
def selectExpr(alias : String) : List[Expression] = { def selectExpr(alias : String) : List[Expression] = {
selectAttributes.map { _.map { a => selectAttributes.map { _.map { a =>
id(alias).field(a).as(s"${alias}_$a") fieldId(alias).field(a).as(s"${alias}_$a")
} }
}.getOrElse(List(id(alias))) }.getOrElse(List(fieldId(alias)))
} }
/** /**
...@@ -184,8 +184,8 @@ trait ClosureQuery { ...@@ -184,8 +184,8 @@ trait ClosureQuery {
* foreach resultRow * foreach resultRow
* for each Path entry * for each Path entry
* add an entry in the edges Map * add an entry in the edges Map
* add an entry for the Src AtlasVertex to the vertex Map * add an entry for the Src vertex to the vertex Map
* add an entry for the Dest AtlasVertex to the vertex Map * add an entry for the Dest vertex to the vertex Map
*/ */
res.rows.map(_.asInstanceOf[StructInstance]).foreach { r => res.rows.map(_.asInstanceOf[StructInstance]).foreach { r =>
...@@ -207,7 +207,7 @@ trait ClosureQuery { ...@@ -207,7 +207,7 @@ trait ClosureQuery {
} }
currVertex = nextVertex currVertex = nextVertex
} }
val AtlasVertex = r.get(TypeUtils.ResultWithPathStruct.resultAttrName) val vertex = r.get(TypeUtils.ResultWithPathStruct.resultAttrName)
vertices.put(id(srcVertex), vertexStruct(srcVertex, vertices.put(id(srcVertex), vertexStruct(srcVertex,
r.get(TypeUtils.ResultWithPathStruct.resultAttrName).asInstanceOf[ITypedStruct], r.get(TypeUtils.ResultWithPathStruct.resultAttrName).asInstanceOf[ITypedStruct],
s"${SRC_PREFIX}_")) s"${SRC_PREFIX}_"))
...@@ -237,7 +237,7 @@ trait SingleInstanceClosureQuery[T] extends ClosureQuery { ...@@ -237,7 +237,7 @@ trait SingleInstanceClosureQuery[T] extends ClosureQuery {
override def srcCondition(expr : Expression) : Expression = { override def srcCondition(expr : Expression) : Expression = {
expr.where( expr.where(
Expressions.id(attributeToSelectInstance).`=`(Expressions.literal(attributeTyp, instanceValue)) Expressions.fieldId(attributeToSelectInstance).`=`(Expressions.literal(attributeTyp, instanceValue))
) )
} }
} }
......
...@@ -387,7 +387,21 @@ object Expressions { ...@@ -387,7 +387,21 @@ object Expressions {
def _trait(name: String) = new TraitExpression(name) def _trait(name: String) = new TraitExpression(name)
case class IdExpression(name: String) extends Expression with LeafNode { object IdExpressionType extends Enumeration {
val Unresolved, NonType = Value;
class IdExpressionTypeValue(exprValue : Value) {
def isTypeAllowed = exprValue match {
case Unresolved => true
case _ => false
}
}
import scala.language.implicitConversions
implicit def value2ExprValue(exprValue: Value) = new IdExpressionTypeValue(exprValue)
}
case class IdExpression(name: String, exprType: IdExpressionType.Value) extends Expression with LeafNode {
override def toString = name override def toString = name
override lazy val resolved = false override lazy val resolved = false
...@@ -395,7 +409,16 @@ object Expressions { ...@@ -395,7 +409,16 @@ object Expressions {
override def dataType = throw new UnresolvedException(this, "id") override def dataType = throw new UnresolvedException(this, "id")
} }
def id(name: String) = new IdExpression(name) /**
* Creates an IdExpression whose allowed value type will be determined
* later.
*/
def id(name: String) = new IdExpression(name, IdExpressionType.Unresolved)
/**
* Creates an IdExpression whose value must resolve to a field name
*/
def fieldId(name: String) = new IdExpression(name, IdExpressionType.NonType)
case class UnresolvedFieldExpression(child: Expression, fieldName: String) extends Expression case class UnresolvedFieldExpression(child: Expression, fieldName: String) extends Expression
with UnaryNode { with UnaryNode {
......
...@@ -338,8 +338,26 @@ object QueryParser extends StandardTokenParsers with QueryKeywords with Expressi ...@@ -338,8 +338,26 @@ object QueryParser extends StandardTokenParsers with QueryKeywords with Expressi
} }
def identifier = rep1sep(ident, DOT) ^^ { l => l match { def identifier = rep1sep(ident, DOT) ^^ { l => l match {
/*
* We don't have enough context here to know what the id can be.
* Examples:
* Column isa PII - "Column" could be a field, type, or alias
* name = 'John' - "name" must be a field.
* Use generic id(), let type the be refined based on the context later.
*/
case h :: Nil => id(h) case h :: Nil => id(h)
case h :: t => {
/*
* Then left-most part of the identifier ("h") must be a can be either. However,
* Atlas does support struct attributes, whose fields must accessed through
* this syntax. Let the downstream processing figure out which case we're in.
*
* Examples:
* hive_table.name - here, hive_table must be a type
* sortCol.order - here, sortCol is a struct attribute, must resolve to a field.
*/
case h :: t => { //the left-most part of the identifier (h) can be
t.foldLeft(id(h).asInstanceOf[Expression])(_.field(_)) t.foldLeft(id(h).asInstanceOf[Expression])(_.field(_))
} }
} }
......
...@@ -42,9 +42,11 @@ object QueryProcessor { ...@@ -42,9 +42,11 @@ object QueryProcessor {
} }
def validate(e: Expression): Expression = { def validate(e: Expression): Expression = {
val e1 = e.transformUp(new Resolver())
e1.traverseUp { val e1 = e.transformUp(refineIdExpressionType);
val e2 = e1.transformUp(new Resolver(None,e1.namedExpressions))
e2.traverseUp {
case x: Expression if !x.resolved => case x: Expression if !x.resolved =>
throw new ExpressionException(x, s"Failed to resolved expression $x") throw new ExpressionException(x, s"Failed to resolved expression $x")
} }
...@@ -52,16 +54,65 @@ object QueryProcessor { ...@@ -52,16 +54,65 @@ object QueryProcessor {
/* /*
* trigger computation of dataType of expression tree * trigger computation of dataType of expression tree
*/ */
e1.dataType e2.dataType
/* /*
* ensure fieldReferences match the input expression's dataType * ensure fieldReferences match the input expression's dataType
*/ */
val e2 = e1.transformUp(FieldValidator) val e3 = e2.transformUp(FieldValidator)
val e3 = e2.transformUp(new Resolver()) val e4 = e3.transformUp(new Resolver(None,e3.namedExpressions))
e4.dataType
e4
}
val convertToFieldIdExpression : PartialFunction[Expression,Expression] = {
case IdExpression(name, IdExpressionType.Unresolved) => IdExpression(name, IdExpressionType.NonType);
}
//this function is called in a depth first manner on the expression tree to set the exprType in IdExpressions
//when we know them. Since Expression classes are immutable, in order to do this we need to create new instances
//of the case. The logic here enumerates the cases that have been identified where the given IdExpression
//cannot resolve to a class or trait. This is the case in any places where a field value must be used.
//For example, you cannot add two classes together or compare traits. Any IdExpressions in those contexts
//refer to unqualified attribute names. On a similar note, select clauses need to product an actual value.
//For example, in 'from DB select name' or 'from DB select name as n', name must be an attribute.
val refineIdExpressionType : PartialFunction[Expression,Expression] = {
//spit out the individual cases to minimize the object churn. Specifically, for ComparsionExpressions where neither
//child is an IdExpression, there is no need to create a new ComparsionExpression object since neither child will
//change. This applies to ArithmeticExpression as well.
case c@ComparisonExpression(symbol, l@IdExpression(_,IdExpressionType.Unresolved) , r@IdExpression(_,IdExpressionType.Unresolved)) => {
ComparisonExpression(symbol, convertToFieldIdExpression(l), convertToFieldIdExpression(r))
}
case c@ComparisonExpression(symbol, l@IdExpression(_,IdExpressionType.Unresolved) , r) => ComparisonExpression(symbol, convertToFieldIdExpression(l), r)
case c@ComparisonExpression(symbol, l, r@IdExpression(_,IdExpressionType.Unresolved)) => ComparisonExpression(symbol, l, convertToFieldIdExpression(r))
case e@ArithmeticExpression(symbol, l@IdExpression(_,IdExpressionType.Unresolved) , r@IdExpression(_,IdExpressionType.Unresolved)) => {
ArithmeticExpression(symbol, convertToFieldIdExpression(l), convertToFieldIdExpression(r))
}
case e@ArithmeticExpression(symbol, l@IdExpression(_,IdExpressionType.Unresolved) , r) => ArithmeticExpression(symbol, convertToFieldIdExpression(l), r)
case e@ArithmeticExpression(symbol, l, r@IdExpression(_,IdExpressionType.Unresolved)) => ArithmeticExpression(symbol, l, convertToFieldIdExpression(r))
e3.dataType case s@SelectExpression(child, selectList, forGroupBy) => {
var changed = false
val newSelectList = selectList.map {
expr => expr match {
case e@IdExpression(_,IdExpressionType.Unresolved) => { changed=true; convertToFieldIdExpression(e) }
case AliasExpression(child@IdExpression(_,IdExpressionType.Unresolved), alias) => {changed=true; AliasExpression(convertToFieldIdExpression(child), alias)}
case x => x
}
}
if(changed) {
SelectExpression(child, newSelectList, forGroupBy)
}
else {
s
}
}
e3
} }
} }
...@@ -20,7 +20,8 @@ package org.apache.atlas.query ...@@ -20,7 +20,8 @@ package org.apache.atlas.query
import org.apache.atlas.query.Expressions._ import org.apache.atlas.query.Expressions._
import org.apache.atlas.typesystem.types.IDataType import org.apache.atlas.typesystem.types.IDataType
import org.apache.atlas.typesystem.types.TraitType
import org.apache.atlas.typesystem.types.ClassType
class Resolver(srcExpr: Option[Expression] = None, aliases: Map[String, Expression] = Map(), class Resolver(srcExpr: Option[Expression] = None, aliases: Map[String, Expression] = Map(),
connectClassExprToSrc: Boolean = false) connectClassExprToSrc: Boolean = false)
extends PartialFunction[Expression, Expression] { extends PartialFunction[Expression, Expression] {
...@@ -30,25 +31,37 @@ class Resolver(srcExpr: Option[Expression] = None, aliases: Map[String, Expressi ...@@ -30,25 +31,37 @@ class Resolver(srcExpr: Option[Expression] = None, aliases: Map[String, Expressi
def isDefinedAt(x: Expression) = true def isDefinedAt(x: Expression) = true
def apply(e: Expression): Expression = e match { def apply(e: Expression): Expression = e match {
case idE@IdExpression(name) => { case idE@IdExpression(name, exprType) => {
val backExpr = aliases.get(name) val backExpr = aliases.get(name)
if (backExpr.isDefined) { if (backExpr.isDefined) {
if(backExpr.get.resolved) {
return new BackReference(name, backExpr.get, None) return new BackReference(name, backExpr.get, None)
} }
else {
//replace once resolved
return idE;
}
}
if (srcExpr.isDefined) { if (srcExpr.isDefined) {
val fInfo = resolveReference(srcExpr.get.dataType, name) val fInfo = resolveReference(srcExpr.get.dataType, name)
if (fInfo.isDefined) { if (fInfo.isDefined) {
return new FieldExpression(name, fInfo.get, None) return new FieldExpression(name, fInfo.get, None)
} }
} }
val cType = resolveAsClassType(name)
if (cType.isDefined) { if(exprType.isTypeAllowed) {
val dt = resolveAsDataType(name);
if(dt.isDefined) {
if(dt.get.isInstanceOf[ClassType]) {
return new ClassExpression(name) return new ClassExpression(name)
} }
val tType = resolveAsTraitType(name) if(dt.get.isInstanceOf[TraitType]) {
if (tType.isDefined) {
return new TraitExpression(name) return new TraitExpression(name)
} }
}
}
idE idE
} }
case ce@ClassExpression(clsName) if connectClassExprToSrc && srcExpr.isDefined => { case ce@ClassExpression(clsName) if connectClassExprToSrc && srcExpr.isDefined => {
......
...@@ -252,6 +252,15 @@ object TypeUtils { ...@@ -252,6 +252,15 @@ object TypeUtils {
None None
} }
def resolveAsDataType(id : String) : Option[IDataType[_]] = {
try {
Some(typSystem.getDataType(id))
} catch {
case _ : AtlasException => None
}
}
def resolveAsClassType(id : String) : Option[ClassType] = { def resolveAsClassType(id : String) : Option[ClassType] = {
try { try {
Some(typSystem.getDataType(classOf[ClassType], id)) Some(typSystem.getDataType(classOf[ClassType], id))
......
...@@ -387,6 +387,7 @@ public class GraphBackedDiscoveryServiceTest extends BaseRepositoryTest { ...@@ -387,6 +387,7 @@ public class GraphBackedDiscoveryServiceTest extends BaseRepositoryTest {
@DataProvider(name = "dslQueriesProvider") @DataProvider(name = "dslQueriesProvider")
private Object[][] createDSLQueries() { private Object[][] createDSLQueries() {
return new Object[][]{ return new Object[][]{
{"hive_db as inst where inst.name=\"Reporting\" select inst as id, inst.name", 1},
{"from hive_db as h select h as id", 3}, {"from hive_db as h select h as id", 3},
{"from hive_db", 3}, {"from hive_db", 3},
{"hive_db", 3}, {"hive_db", 3},
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.atlas.query;
import static org.apache.atlas.typesystem.types.utils.TypesUtil.createClassTypeDef;
import static org.apache.atlas.typesystem.types.utils.TypesUtil.createRequiredAttrDef;
import java.util.HashSet;
import java.util.Set;
import org.apache.atlas.AtlasException;
import org.apache.atlas.typesystem.types.ClassType;
import org.apache.atlas.typesystem.types.DataTypes;
import org.apache.atlas.typesystem.types.DataTypes.TypeCategory;
import org.apache.atlas.typesystem.types.HierarchicalTypeDefinition;
import org.apache.atlas.typesystem.types.IDataType;
import org.apache.atlas.typesystem.types.TypeSystem;
import org.apache.atlas.typesystem.types.cache.DefaultTypeCache;
import org.testng.annotations.Test;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.assertFalse;
import com.google.common.collect.ImmutableSet;
import scala.util.Either;
import scala.util.parsing.combinator.Parsers;
/**
* Tests the logic for skipping type cache lookup for things that
* cannot be types.
*
*/
public class QueryProcessorTest {
@Test
public void testAliasesNotTreatedAsTypes() throws Exception {
ValidatingTypeCache tc = findTypeLookupsDuringQueryParsing("hive_db as inst where inst.name=\"Reporting\" select inst as id, inst.name");
assertTrue(tc.wasTypeRequested("hive_db"));
assertFalse(tc.wasTypeRequested("inst"));
assertFalse(tc.wasTypeRequested("name"));
}
@Test
public void testFieldInComparisionNotTreatedAsType() throws Exception {
//test when the IdExpression is on the left, on the right, and on both sides of the ComparsionExpression
ValidatingTypeCache tc = findTypeLookupsDuringQueryParsing("hive_db where name=\"Reporting\" or \"Reporting\" = name or name=name");
assertTrue(tc.wasTypeRequested("hive_db"));
assertFalse(tc.wasTypeRequested("name"));
}
@Test
public void testFieldInArithmeticExprNotTreatedAsType() throws Exception {
//test when the IdExpression is on the left, on the right, and on both sides of the ArithmeticExpression
ValidatingTypeCache tc = findTypeLookupsDuringQueryParsing("hive_db where (tableCount + 3) > (tableCount + tableCount) select (3 + tableCount) as updatedCount");
assertTrue(tc.wasTypeRequested("hive_db"));
assertFalse(tc.wasTypeRequested("tableCount"));
assertFalse(tc.wasTypeRequested("updatedCount"));
}
@Test
public void testFieldInSelectListWithAlasNotTreatedAsType() throws Exception {
ValidatingTypeCache tc = findTypeLookupsDuringQueryParsing("hive_db select name as theName");
assertTrue(tc.wasTypeRequested("hive_db"));
assertFalse(tc.wasTypeRequested("theName"));
assertFalse(tc.wasTypeRequested("name"));
}
@Test
public void testFieldInSelectListNotTreatedAsType() throws Exception {
ValidatingTypeCache tc = findTypeLookupsDuringQueryParsing("hive_db select name");
assertTrue(tc.wasTypeRequested("hive_db"));
assertFalse(tc.wasTypeRequested("name"));
}
private ValidatingTypeCache findTypeLookupsDuringQueryParsing(String query) throws AtlasException {
TypeSystem typeSystem = TypeSystem.getInstance();
ValidatingTypeCache result = new ValidatingTypeCache();
typeSystem.setTypeCache(result);
typeSystem.reset();
HierarchicalTypeDefinition<ClassType> hiveTypeDef = createClassTypeDef("hive_db", "", ImmutableSet.<String>of(),
createRequiredAttrDef("name", DataTypes.STRING_TYPE),
createRequiredAttrDef("tableCount", DataTypes.INT_TYPE)
);
typeSystem.defineClassType(hiveTypeDef);
Either<Parsers.NoSuccess, Expressions.Expression> either = QueryParser.apply(query, null);
Expressions.Expression expression = either.right().get();
QueryProcessor.validate(expression);
return result;
}
private static class ValidatingTypeCache extends DefaultTypeCache {
private Set<String> typesRequested = new HashSet<>();
@Override
public boolean has(String typeName) throws AtlasException {
typesRequested.add(typeName);
return super.has(typeName);
}
@Override
public boolean has(TypeCategory typeCategory, String typeName) throws AtlasException {
typesRequested.add(typeName);
return super.has(typeCategory, typeName);
}
@Override
public IDataType get(String typeName) throws AtlasException {
typesRequested.add(typeName);
return super.get(typeName);
}
@Override
public IDataType get(TypeCategory typeCategory, String typeName) throws AtlasException {
typesRequested.add(typeName);
return super.get(typeCategory, typeName);
}
public boolean wasTypeRequested(String name) {
return typesRequested.contains(name);
}
}
}
...@@ -149,17 +149,13 @@ public class TypeSystem { ...@@ -149,17 +149,13 @@ public class TypeSystem {
return coreTypes.containsKey(typeName); return coreTypes.containsKey(typeName);
} }
public <T> T getDataType(Class<T> cls, String name) throws AtlasException { public IDataType getDataType(String name) throws AtlasException {
if (isCoreType(name)) { if (isCoreType(name)) {
return cls.cast(coreTypes.get(name)); return coreTypes.get(name);
} }
if (typeCache.has(name)) { if (typeCache.has(name)) {
try { return typeCache.get(name);
return cls.cast(typeCache.get(name));
} catch (ClassCastException cce) {
throw new AtlasException(cce);
}
} }
/* /*
...@@ -167,8 +163,8 @@ public class TypeSystem { ...@@ -167,8 +163,8 @@ public class TypeSystem {
*/ */
String arrElemType = TypeUtils.parseAsArrayType(name); String arrElemType = TypeUtils.parseAsArrayType(name);
if (arrElemType != null) { if (arrElemType != null) {
IDataType dT = defineArrayType(getDataType(IDataType.class, arrElemType)); IDataType dT = defineArrayType(getDataType(arrElemType));
return cls.cast(dT); return dT;
} }
/* /*
...@@ -177,8 +173,8 @@ public class TypeSystem { ...@@ -177,8 +173,8 @@ public class TypeSystem {
String[] mapType = TypeUtils.parseAsMapType(name); String[] mapType = TypeUtils.parseAsMapType(name);
if (mapType != null) { if (mapType != null) {
IDataType dT = IDataType dT =
defineMapType(getDataType(IDataType.class, mapType[0]), getDataType(IDataType.class, mapType[1])); defineMapType(getDataType(mapType[0]), getDataType(mapType[1]));
return cls.cast(dT); return dT;
} }
/* /*
...@@ -186,12 +182,22 @@ public class TypeSystem { ...@@ -186,12 +182,22 @@ public class TypeSystem {
*/ */
IDataType dT = typeCache.onTypeFault(name); IDataType dT = typeCache.onTypeFault(name);
if (dT != null) { if (dT != null) {
return cls.cast(dT); return dT;
} }
throw new TypeNotFoundException(String.format("Unknown datatype: %s", name)); throw new TypeNotFoundException(String.format("Unknown datatype: %s", name));
} }
public <T extends IDataType> T getDataType(Class<T> cls, String name) throws AtlasException {
try {
IDataType dt = getDataType(name);
return cls.cast(dt);
} catch (ClassCastException cce) {
throw new AtlasException(cce);
}
}
public StructType defineStructType(String name, boolean errorIfExists, AttributeDefinition... attrDefs) public StructType defineStructType(String name, boolean errorIfExists, AttributeDefinition... attrDefs)
throws AtlasException { throws AtlasException {
return defineStructType(name, null, errorIfExists, attrDefs); return defineStructType(name, null, errorIfExists, attrDefs);
...@@ -636,14 +642,11 @@ public class TypeSystem { ...@@ -636,14 +642,11 @@ public class TypeSystem {
//get from transient types. Else, from main type system //get from transient types. Else, from main type system
@Override @Override
public <T> T getDataType(Class<T> cls, String name) throws AtlasException { public IDataType getDataType(String name) throws AtlasException {
if (transientTypes != null) { if (transientTypes != null) {
if (transientTypes.containsKey(name)) { if (transientTypes.containsKey(name)) {
try { return transientTypes.get(name);
return cls.cast(transientTypes.get(name));
} catch (ClassCastException cce) {
throw new AtlasException(cce);
}
} }
/* /*
...@@ -652,7 +655,7 @@ public class TypeSystem { ...@@ -652,7 +655,7 @@ public class TypeSystem {
String arrElemType = TypeUtils.parseAsArrayType(name); String arrElemType = TypeUtils.parseAsArrayType(name);
if (arrElemType != null) { if (arrElemType != null) {
IDataType dT = defineArrayType(getDataType(IDataType.class, arrElemType)); IDataType dT = defineArrayType(getDataType(IDataType.class, arrElemType));
return cls.cast(dT); return dT;
} }
/* /*
...@@ -662,11 +665,11 @@ public class TypeSystem { ...@@ -662,11 +665,11 @@ public class TypeSystem {
if (mapType != null) { if (mapType != null) {
IDataType dT = IDataType dT =
defineMapType(getDataType(IDataType.class, mapType[0]), getDataType(IDataType.class, mapType[1])); defineMapType(getDataType(IDataType.class, mapType[0]), getDataType(IDataType.class, mapType[1]));
return cls.cast(dT); return dT;
} }
} }
return TypeSystem.this.getDataType(cls, name); return TypeSystem.this.getDataType(name);
} }
@Override @Override
......
...@@ -411,7 +411,8 @@ public abstract class BaseResourceIT { ...@@ -411,7 +411,8 @@ public abstract class BaseResourceIT {
} }
protected String randomString() { protected String randomString() {
return RandomStringUtils.randomAlphanumeric(10); //names cannot start with a digit
return RandomStringUtils.randomAlphabetic(1) + RandomStringUtils.randomAlphanumeric(9);
} }
protected Referenceable createHiveTableInstanceV1(String dbName, String tableName, Id dbId) throws Exception { protected Referenceable createHiveTableInstanceV1(String dbName, String tableName, Id dbId) throws Exception {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment