You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
380 lines
11 KiB
380 lines
11 KiB
package gorm |
|
|
|
import ( |
|
"errors" |
|
"fmt" |
|
"reflect" |
|
"strconv" |
|
"strings" |
|
) |
|
|
|
// preloadCallback used to preload associations |
|
func preloadCallback(scope *Scope) { |
|
|
|
if _, ok := scope.Get("gorm:auto_preload"); ok { |
|
autoPreload(scope) |
|
} |
|
|
|
if scope.Search.preload == nil || scope.HasError() { |
|
return |
|
} |
|
|
|
var ( |
|
preloadedMap = map[string]bool{} |
|
fields = scope.Fields() |
|
) |
|
|
|
for _, preload := range scope.Search.preload { |
|
var ( |
|
preloadFields = strings.Split(preload.schema, ".") |
|
currentScope = scope |
|
currentFields = fields |
|
) |
|
|
|
for idx, preloadField := range preloadFields { |
|
var currentPreloadConditions []interface{} |
|
|
|
if currentScope == nil { |
|
continue |
|
} |
|
|
|
// if not preloaded |
|
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { |
|
|
|
// assign search conditions to last preload |
|
if idx == len(preloadFields)-1 { |
|
currentPreloadConditions = preload.conditions |
|
} |
|
|
|
for _, field := range currentFields { |
|
if field.Name != preloadField || field.Relationship == nil { |
|
continue |
|
} |
|
|
|
switch field.Relationship.Kind { |
|
case "has_one": |
|
currentScope.handleHasOnePreload(field, currentPreloadConditions) |
|
case "has_many": |
|
currentScope.handleHasManyPreload(field, currentPreloadConditions) |
|
case "belongs_to": |
|
currentScope.handleBelongsToPreload(field, currentPreloadConditions) |
|
case "many_to_many": |
|
currentScope.handleManyToManyPreload(field, currentPreloadConditions) |
|
default: |
|
scope.Err(errors.New("unsupported relation")) |
|
} |
|
|
|
preloadedMap[preloadKey] = true |
|
break |
|
} |
|
|
|
if !preloadedMap[preloadKey] { |
|
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) |
|
return |
|
} |
|
} |
|
|
|
// preload next level |
|
if idx < len(preloadFields)-1 { |
|
currentScope = currentScope.getColumnAsScope(preloadField) |
|
if currentScope != nil { |
|
currentFields = currentScope.Fields() |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
func autoPreload(scope *Scope) { |
|
for _, field := range scope.Fields() { |
|
if field.Relationship == nil { |
|
continue |
|
} |
|
|
|
if val, ok := field.TagSettings["PRELOAD"]; ok { |
|
if preload, err := strconv.ParseBool(val); err != nil { |
|
scope.Err(errors.New("invalid preload option")) |
|
return |
|
} else if !preload { |
|
continue |
|
} |
|
} |
|
|
|
scope.Search.Preload(field.Name) |
|
} |
|
} |
|
|
|
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { |
|
var ( |
|
preloadDB = scope.NewDB() |
|
preloadConditions []interface{} |
|
) |
|
|
|
for _, condition := range conditions { |
|
if scopes, ok := condition.(func(*DB) *DB); ok { |
|
preloadDB = scopes(preloadDB) |
|
} else { |
|
preloadConditions = append(preloadConditions, condition) |
|
} |
|
} |
|
|
|
return preloadDB, preloadConditions |
|
} |
|
|
|
// handleHasOnePreload used to preload has one associations |
|
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { |
|
relation := field.Relationship |
|
|
|
// get relations's primary keys |
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) |
|
if len(primaryKeys) == 0 { |
|
return |
|
} |
|
|
|
// preload conditions |
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) |
|
|
|
// find relations |
|
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) |
|
values := toQueryValues(primaryKeys) |
|
if relation.PolymorphicType != "" { |
|
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) |
|
values = append(values, relation.PolymorphicValue) |
|
} |
|
|
|
results := makeSlice(field.Struct.Type) |
|
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) |
|
|
|
// assign find results |
|
var ( |
|
resultsValue = indirect(reflect.ValueOf(results)) |
|
indirectScopeValue = scope.IndirectValue() |
|
) |
|
|
|
if indirectScopeValue.Kind() == reflect.Slice { |
|
for j := 0; j < indirectScopeValue.Len(); j++ { |
|
for i := 0; i < resultsValue.Len(); i++ { |
|
result := resultsValue.Index(i) |
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames) |
|
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { |
|
indirectValue.FieldByName(field.Name).Set(result) |
|
break |
|
} |
|
} |
|
} |
|
} else { |
|
for i := 0; i < resultsValue.Len(); i++ { |
|
result := resultsValue.Index(i) |
|
scope.Err(field.Set(result)) |
|
} |
|
} |
|
} |
|
|
|
// handleHasManyPreload used to preload has many associations |
|
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { |
|
relation := field.Relationship |
|
|
|
// get relations's primary keys |
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) |
|
if len(primaryKeys) == 0 { |
|
return |
|
} |
|
|
|
// preload conditions |
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) |
|
|
|
// find relations |
|
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) |
|
values := toQueryValues(primaryKeys) |
|
if relation.PolymorphicType != "" { |
|
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) |
|
values = append(values, relation.PolymorphicValue) |
|
} |
|
|
|
results := makeSlice(field.Struct.Type) |
|
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) |
|
|
|
// assign find results |
|
var ( |
|
resultsValue = indirect(reflect.ValueOf(results)) |
|
indirectScopeValue = scope.IndirectValue() |
|
) |
|
|
|
if indirectScopeValue.Kind() == reflect.Slice { |
|
preloadMap := make(map[string][]reflect.Value) |
|
for i := 0; i < resultsValue.Len(); i++ { |
|
result := resultsValue.Index(i) |
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames) |
|
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) |
|
} |
|
|
|
for j := 0; j < indirectScopeValue.Len(); j++ { |
|
object := indirect(indirectScopeValue.Index(j)) |
|
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) |
|
f := object.FieldByName(field.Name) |
|
if results, ok := preloadMap[toString(objectRealValue)]; ok { |
|
f.Set(reflect.Append(f, results...)) |
|
} else { |
|
f.Set(reflect.MakeSlice(f.Type(), 0, 0)) |
|
} |
|
} |
|
} else { |
|
scope.Err(field.Set(resultsValue)) |
|
} |
|
} |
|
|
|
// handleBelongsToPreload used to preload belongs to associations |
|
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { |
|
relation := field.Relationship |
|
|
|
// preload conditions |
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) |
|
|
|
// get relations's primary keys |
|
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) |
|
if len(primaryKeys) == 0 { |
|
return |
|
} |
|
|
|
// find relations |
|
results := makeSlice(field.Struct.Type) |
|
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) |
|
|
|
// assign find results |
|
var ( |
|
resultsValue = indirect(reflect.ValueOf(results)) |
|
indirectScopeValue = scope.IndirectValue() |
|
) |
|
|
|
for i := 0; i < resultsValue.Len(); i++ { |
|
result := resultsValue.Index(i) |
|
if indirectScopeValue.Kind() == reflect.Slice { |
|
value := getValueFromFields(result, relation.AssociationForeignFieldNames) |
|
for j := 0; j < indirectScopeValue.Len(); j++ { |
|
object := indirect(indirectScopeValue.Index(j)) |
|
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { |
|
object.FieldByName(field.Name).Set(result) |
|
} |
|
} |
|
} else { |
|
scope.Err(field.Set(result)) |
|
} |
|
} |
|
} |
|
|
|
// handleManyToManyPreload used to preload many to many associations |
|
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { |
|
var ( |
|
relation = field.Relationship |
|
joinTableHandler = relation.JoinTableHandler |
|
fieldType = field.Struct.Type.Elem() |
|
foreignKeyValue interface{} |
|
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() |
|
linkHash = map[string][]reflect.Value{} |
|
isPtr bool |
|
) |
|
|
|
if fieldType.Kind() == reflect.Ptr { |
|
isPtr = true |
|
fieldType = fieldType.Elem() |
|
} |
|
|
|
var sourceKeys = []string{} |
|
for _, key := range joinTableHandler.SourceForeignKeys() { |
|
sourceKeys = append(sourceKeys, key.DBName) |
|
} |
|
|
|
// preload conditions |
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) |
|
|
|
// generate query with join table |
|
newScope := scope.New(reflect.New(fieldType).Interface()) |
|
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) |
|
|
|
if len(preloadDB.search.selects) == 0 { |
|
preloadDB = preloadDB.Select("*") |
|
} |
|
|
|
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) |
|
|
|
// preload inline conditions |
|
if len(preloadConditions) > 0 { |
|
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) |
|
} |
|
|
|
rows, err := preloadDB.Rows() |
|
|
|
if scope.Err(err) != nil { |
|
return |
|
} |
|
defer rows.Close() |
|
|
|
columns, _ := rows.Columns() |
|
for rows.Next() { |
|
var ( |
|
elem = reflect.New(fieldType).Elem() |
|
fields = scope.New(elem.Addr().Interface()).Fields() |
|
) |
|
|
|
// register foreign keys in join tables |
|
var joinTableFields []*Field |
|
for _, sourceKey := range sourceKeys { |
|
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) |
|
} |
|
|
|
scope.scan(rows, columns, append(fields, joinTableFields...)) |
|
|
|
var foreignKeys = make([]interface{}, len(sourceKeys)) |
|
// generate hashed forkey keys in join table |
|
for idx, joinTableField := range joinTableFields { |
|
if !joinTableField.Field.IsNil() { |
|
foreignKeys[idx] = joinTableField.Field.Elem().Interface() |
|
} |
|
} |
|
hashedSourceKeys := toString(foreignKeys) |
|
|
|
if isPtr { |
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) |
|
} else { |
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) |
|
} |
|
} |
|
|
|
if err := rows.Err(); err != nil { |
|
scope.Err(err) |
|
} |
|
|
|
// assign find results |
|
var ( |
|
indirectScopeValue = scope.IndirectValue() |
|
fieldsSourceMap = map[string][]reflect.Value{} |
|
foreignFieldNames = []string{} |
|
) |
|
|
|
for _, dbName := range relation.ForeignFieldNames { |
|
if field, ok := scope.FieldByName(dbName); ok { |
|
foreignFieldNames = append(foreignFieldNames, field.Name) |
|
} |
|
} |
|
|
|
if indirectScopeValue.Kind() == reflect.Slice { |
|
for j := 0; j < indirectScopeValue.Len(); j++ { |
|
object := indirect(indirectScopeValue.Index(j)) |
|
key := toString(getValueFromFields(object, foreignFieldNames)) |
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) |
|
} |
|
} else if indirectScopeValue.IsValid() { |
|
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) |
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) |
|
} |
|
for source, link := range linkHash { |
|
for i, field := range fieldsSourceMap[source] { |
|
//If not 0 this means Value is a pointer and we already added preloaded models to it |
|
if fieldsSourceMap[source][i].Len() != 0 { |
|
continue |
|
} |
|
field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) |
|
} |
|
|
|
} |
|
}
|
|
|