diff --git a/dbHandler.go b/dbHandler.go index 3f71df8..71956ac 100644 --- a/dbHandler.go +++ b/dbHandler.go @@ -129,14 +129,22 @@ func (dH *DBHandler) AddNewColum(model any) error { // // Returns: // - error: Any query error or “not found” message. -func (dH *DBHandler) GetById(model any, id uint) error { +func (dH *DBHandler) GetById(model any, relation string, id uint) (err error) { dH.logger.Debug("getById", "find id "+fmt.Sprint(id)) if id == 0 { + if relation != "" { + return dH.db.Preload(relation).Find(model).Error + } return dH.db.Find(model).Error } - err := dH.db.First(model, id).Error + if relation != "" { + err = dH.db.Preload(relation).First(model, id).Error + } else { + err = dH.db.First(model, id).Error + } + if errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("no record found for id: %v", id) } else if err != nil { @@ -163,14 +171,20 @@ func (dH *DBHandler) GetById(model any, id uint) error { // // Returns: // - error: Any database query error. -func (dH *DBHandler) GetByKey(model any, key string, value any, LikeSearch bool) error { +func (dH *DBHandler) GetByKey(model any, relation string, key string, value any, LikeSearch bool) error { if LikeSearch { value = strings.ReplaceAll(fmt.Sprint(value), "*", "%") dH.logger.Debug("getByKey", "find like key "+key+" value "+fmt.Sprint(value)) + if relation != "" { + return dH.db.Preload(relation).Where(key+" LIKE ?", value).Delete(model).Error + } return dH.db.Where(key+" LIKE ?", value).Find(model).Error } dH.logger.Debug("getByKey", "find equal key "+key+" value "+fmt.Sprint(value)) + if relation != "" { + return dH.db.Preload(relation).Where(key+" = ?", value).Delete(model).Error + } return dH.db.Find(model, key+" = ?", value).Error } @@ -191,7 +205,7 @@ func (dH *DBHandler) GetByKey(model any, key string, value any, LikeSearch bool) // // Returns: // - error: If the model is invalid or query/update fails. -func (dH *DBHandler) UpdateValuesById(model any, id uint) error { +func (dH *DBHandler) UpdateValuesById(model any, relation string, id uint) error { dH.logger.Debug("updateValuesById", "model"+fmt.Sprint(model)) modelType := reflect.TypeOf(model) if modelType.Kind() != reflect.Ptr { @@ -199,9 +213,12 @@ func (dH *DBHandler) UpdateValuesById(model any, id uint) error { } lookUpModel := reflect.New(modelType.Elem()).Interface() - if err := dH.GetById(lookUpModel, id); err != nil { + if err := dH.GetById(lookUpModel, relation, id); err != nil { return err } + if relation != "" { + return dH.db.Preload(relation).Model(lookUpModel).Updates(model).Error + } return dH.db.Model(lookUpModel).Updates(model).Error } @@ -223,7 +240,7 @@ func (dH *DBHandler) UpdateValuesById(model any, id uint) error { // // Returns: // - error: Any query or update error. -func (dH *DBHandler) UpdateValuesByKey(model any, key string, value any) error { +func (dH *DBHandler) UpdateValuesByKey(model any, relation string, key string, value any) error { dH.logger.Debug("updateValuesByKey", "model"+fmt.Sprint(model)) modelType := reflect.TypeOf(model) if modelType.Kind() != reflect.Ptr { @@ -231,9 +248,12 @@ func (dH *DBHandler) UpdateValuesByKey(model any, key string, value any) error { } lookUpModel := reflect.New(modelType.Elem()).Interface() - if err := dH.GetByKey(lookUpModel, key, value, false); err != nil { + if err := dH.GetByKey(lookUpModel, "", key, value, false); err != nil { return err } + if relation != "" { + return dH.db.Preload(relation).Model(lookUpModel).Updates(model).Error + } return dH.db.Model(lookUpModel).Updates(model).Error } @@ -322,23 +342,6 @@ func (dH *DBHandler) Exists(model any, key string, value any, likeSearch bool) ( return tx.RowsAffected > 0 } -func (dH *DBHandler) GetByIdWithRelation(model any, relation string, id uint) error { - if id == 0 { - return dH.db.Preload(relation).Find(model).Error - } - - // Load member AND its events - err := dH.db.Preload(relation).First(model, id).Error - - if errors.Is(err, gorm.ErrRecordNotFound) { - return fmt.Errorf("no record found for id: %v", id) - } else if err != nil { - return fmt.Errorf("query failed: %w", err) - } - return nil -} - func (dH *DBHandler) AddRelation(model, relation any, relationName string) error { return dH.db.Model(model).Association(relationName).Append(relation) - } diff --git a/db_test.go b/db_test.go index 826c4f5..f5103e1 100644 --- a/db_test.go +++ b/db_test.go @@ -12,10 +12,12 @@ type Event struct { } type Member struct { - Id int `gorm:"primaryKey" json:"id"` - FirstName string `gorm:"column:firstName" json:"firstName,omitempty"` - LastName string `gorm:"column:lastName" json:"lastName,omitempty"` - Events []*Event `gorm:"many2many:member_events;" json:"events"` + Id int `gorm:"primaryKey" json:"id"` + FirstName string `gorm:"column:firstName" json:"firstName,omitempty"` + LastName string `gorm:"column:lastName" json:"lastName,omitempty"` + ResponsiblePersonID *int `json:"responsiblePersonId"` + ResponsiblePerson *Member `gorm:"foreignKey:ResponsiblePersonID" json:"responsiblePerson"` + Events []*Event `gorm:"many2many:member_events;" json:"events"` } func TestDbHandler(t *testing.T) { @@ -33,21 +35,23 @@ func TestDbHandler(t *testing.T) { t.Fatal(err) } - member := &Member{FirstName: "adrian", LastName: "zuercher"} - dbHandler.AddNewColum(member) - event := &Event{ - Name: "testEvent", + membersIN := []*Member{{FirstName: "adrian", LastName: "zuercher"}, {FirstName: "adichild", LastName: "zuercher"}} + for _, member := range membersIN { + dbHandler.AddNewColum(member) + event := &Event{ + Name: "testEvent", + } + dbHandler.AddNewColum(event) + dbHandler.AddRelation(event, member, "Attendees") } - dbHandler.AddNewColum(event) - dbHandler.AddRelation(event, member, "Attendees") var members []Member - if err := dbHandler.GetById(&members, 0); err != nil { + if err := dbHandler.GetById(&members, "", 0); err != nil { t.Fatal(err) } t.Log(members) - if err := dbHandler.GetByIdWithRelation(&members, "Events", 0); err != nil { + if err := dbHandler.GetById(&members, "Events", 0); err != nil { t.Fatal(err) } t.Log(members) @@ -61,7 +65,7 @@ func TestDbHandler(t *testing.T) { var events []Event - if err := dbHandler.GetByIdWithRelation(&events, "Attendees", 0); err != nil { + if err := dbHandler.GetById(&events, "Attendees", 0); err != nil { t.Fatal(err) } @@ -72,6 +76,27 @@ func TestDbHandler(t *testing.T) { } } + //add responsible + mem1 := &Member{} + if err := dbHandler.GetById(mem1, "", 2); err != nil { + t.Fatal(err) + } + + mem2 := &Member{} + if err := dbHandler.GetById(mem2, "", 1); err != nil { + t.Fatal(err) + } + + mem1.ResponsiblePerson = mem2 + dbHandler.UpdateValuesById(mem1, "", uint(mem1.Id)) + + if err := dbHandler.GetById(mem1, "", 2); err != nil { + t.Fatal(err) + } + + t.Log(mem1) + t.Log(mem1.ResponsiblePerson) + if err := dbHandler.Close(); err != nil { t.Fatal(err) }