session_insert.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "github.com/go-xorm/core"
  12. )
  13. // Insert insert one or more beans
  14. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  15. var affected int64
  16. var err error
  17. if session.IsAutoClose {
  18. defer session.Close()
  19. }
  20. defer session.resetStatement()
  21. for _, bean := range beans {
  22. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  23. if sliceValue.Kind() == reflect.Slice {
  24. size := sliceValue.Len()
  25. if size > 0 {
  26. if session.Engine.SupportInsertMany() {
  27. cnt, err := session.innerInsertMulti(bean)
  28. if err != nil {
  29. return affected, err
  30. }
  31. affected += cnt
  32. } else {
  33. for i := 0; i < size; i++ {
  34. cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
  35. if err != nil {
  36. return affected, err
  37. }
  38. affected += cnt
  39. }
  40. }
  41. }
  42. } else {
  43. cnt, err := session.innerInsert(bean)
  44. if err != nil {
  45. return affected, err
  46. }
  47. affected += cnt
  48. }
  49. }
  50. return affected, err
  51. }
  52. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  53. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  54. if sliceValue.Kind() != reflect.Slice {
  55. return 0, errors.New("needs a pointer to a slice")
  56. }
  57. if sliceValue.Len() <= 0 {
  58. return 0, errors.New("could not insert a empty slice")
  59. }
  60. if err := session.Statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
  61. return 0, err
  62. }
  63. if len(session.Statement.TableName()) <= 0 {
  64. return 0, ErrTableNotFound
  65. }
  66. table := session.Statement.RefTable
  67. size := sliceValue.Len()
  68. var colNames []string
  69. var colMultiPlaces []string
  70. var args []interface{}
  71. var cols []*core.Column
  72. for i := 0; i < size; i++ {
  73. v := sliceValue.Index(i)
  74. vv := reflect.Indirect(v)
  75. elemValue := v.Interface()
  76. var colPlaces []string
  77. // handle BeforeInsertProcessor
  78. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  79. for _, closure := range session.beforeClosures {
  80. closure(elemValue)
  81. }
  82. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  83. processor.BeforeInsert()
  84. }
  85. // --
  86. if i == 0 {
  87. for _, col := range table.Columns() {
  88. ptrFieldValue, err := col.ValueOfV(&vv)
  89. if err != nil {
  90. return 0, err
  91. }
  92. fieldValue := *ptrFieldValue
  93. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  94. continue
  95. }
  96. if col.MapType == core.ONLYFROMDB {
  97. continue
  98. }
  99. if col.IsDeleted {
  100. continue
  101. }
  102. if session.Statement.ColumnStr != "" {
  103. if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
  104. continue
  105. }
  106. }
  107. if session.Statement.OmitStr != "" {
  108. if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
  109. continue
  110. }
  111. }
  112. if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
  113. val, t := session.Engine.NowTime2(col.SQLType.Name)
  114. args = append(args, val)
  115. var colName = col.Name
  116. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  117. col := table.GetColumn(colName)
  118. setColumnTime(bean, col, t)
  119. })
  120. } else if col.IsVersion && session.Statement.checkVersion {
  121. args = append(args, 1)
  122. var colName = col.Name
  123. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  124. col := table.GetColumn(colName)
  125. setColumnInt(bean, col, 1)
  126. })
  127. } else {
  128. arg, err := session.value2Interface(col, fieldValue)
  129. if err != nil {
  130. return 0, err
  131. }
  132. args = append(args, arg)
  133. }
  134. colNames = append(colNames, col.Name)
  135. cols = append(cols, col)
  136. colPlaces = append(colPlaces, "?")
  137. }
  138. } else {
  139. for _, col := range cols {
  140. ptrFieldValue, err := col.ValueOfV(&vv)
  141. if err != nil {
  142. return 0, err
  143. }
  144. fieldValue := *ptrFieldValue
  145. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  146. continue
  147. }
  148. if col.MapType == core.ONLYFROMDB {
  149. continue
  150. }
  151. if col.IsDeleted {
  152. continue
  153. }
  154. if session.Statement.ColumnStr != "" {
  155. if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
  156. continue
  157. }
  158. }
  159. if session.Statement.OmitStr != "" {
  160. if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
  161. continue
  162. }
  163. }
  164. if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
  165. val, t := session.Engine.NowTime2(col.SQLType.Name)
  166. args = append(args, val)
  167. var colName = col.Name
  168. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  169. col := table.GetColumn(colName)
  170. setColumnTime(bean, col, t)
  171. })
  172. } else if col.IsVersion && session.Statement.checkVersion {
  173. args = append(args, 1)
  174. var colName = col.Name
  175. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  176. col := table.GetColumn(colName)
  177. setColumnInt(bean, col, 1)
  178. })
  179. } else {
  180. arg, err := session.value2Interface(col, fieldValue)
  181. if err != nil {
  182. return 0, err
  183. }
  184. args = append(args, arg)
  185. }
  186. colPlaces = append(colPlaces, "?")
  187. }
  188. }
  189. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  190. }
  191. cleanupProcessorsClosures(&session.beforeClosures)
  192. var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
  193. var statement string
  194. if session.Engine.dialect.DBType() == core.ORACLE {
  195. sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
  196. temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
  197. session.Engine.Quote(session.Statement.TableName()),
  198. session.Engine.QuoteStr(),
  199. strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
  200. session.Engine.QuoteStr())
  201. statement = fmt.Sprintf(sql,
  202. session.Engine.Quote(session.Statement.TableName()),
  203. session.Engine.QuoteStr(),
  204. strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
  205. session.Engine.QuoteStr(),
  206. strings.Join(colMultiPlaces, temp))
  207. } else {
  208. statement = fmt.Sprintf(sql,
  209. session.Engine.Quote(session.Statement.TableName()),
  210. session.Engine.QuoteStr(),
  211. strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
  212. session.Engine.QuoteStr(),
  213. strings.Join(colMultiPlaces, "),("))
  214. }
  215. res, err := session.exec(statement, args...)
  216. if err != nil {
  217. return 0, err
  218. }
  219. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  220. session.cacheInsert(session.Statement.TableName())
  221. }
  222. lenAfterClosures := len(session.afterClosures)
  223. for i := 0; i < size; i++ {
  224. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  225. // handle AfterInsertProcessor
  226. if session.IsAutoCommit {
  227. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  228. for _, closure := range session.afterClosures {
  229. closure(elemValue)
  230. }
  231. if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  232. processor.AfterInsert()
  233. }
  234. } else {
  235. if lenAfterClosures > 0 {
  236. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  237. *value = append(*value, session.afterClosures...)
  238. } else {
  239. afterClosures := make([]func(interface{}), lenAfterClosures)
  240. copy(afterClosures, session.afterClosures)
  241. session.afterInsertBeans[elemValue] = &afterClosures
  242. }
  243. } else {
  244. if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  245. session.afterInsertBeans[elemValue] = nil
  246. }
  247. }
  248. }
  249. }
  250. cleanupProcessorsClosures(&session.afterClosures)
  251. return res.RowsAffected()
  252. }
  253. // InsertMulti insert multiple records
  254. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  255. defer session.resetStatement()
  256. if session.IsAutoClose {
  257. defer session.Close()
  258. }
  259. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  260. if sliceValue.Kind() != reflect.Slice {
  261. return 0, ErrParamsType
  262. }
  263. if sliceValue.Len() <= 0 {
  264. return 0, nil
  265. }
  266. return session.innerInsertMulti(rowsSlicePtr)
  267. }
  268. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  269. if err := session.Statement.setRefValue(rValue(bean)); err != nil {
  270. return 0, err
  271. }
  272. if len(session.Statement.TableName()) <= 0 {
  273. return 0, ErrTableNotFound
  274. }
  275. table := session.Statement.RefTable
  276. // handle BeforeInsertProcessor
  277. for _, closure := range session.beforeClosures {
  278. closure(bean)
  279. }
  280. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  281. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  282. processor.BeforeInsert()
  283. }
  284. // --
  285. colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false)
  286. if err != nil {
  287. return 0, err
  288. }
  289. // insert expr columns, override if exists
  290. exprColumns := session.Statement.getExpr()
  291. exprColVals := make([]string, 0, len(exprColumns))
  292. for _, v := range exprColumns {
  293. // remove the expr columns
  294. for i, colName := range colNames {
  295. if colName == v.colName {
  296. colNames = append(colNames[:i], colNames[i+1:]...)
  297. args = append(args[:i], args[i+1:]...)
  298. }
  299. }
  300. // append expr column to the end
  301. colNames = append(colNames, v.colName)
  302. exprColVals = append(exprColVals, v.expr)
  303. }
  304. colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
  305. if len(exprColVals) > 0 {
  306. colPlaces = colPlaces + strings.Join(exprColVals, ", ")
  307. } else {
  308. if len(colPlaces) > 0 {
  309. colPlaces = colPlaces[0 : len(colPlaces)-2]
  310. }
  311. }
  312. var sqlStr string
  313. if len(colPlaces) > 0 {
  314. sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
  315. session.Engine.Quote(session.Statement.TableName()),
  316. session.Engine.QuoteStr(),
  317. strings.Join(colNames, session.Engine.Quote(", ")),
  318. session.Engine.QuoteStr(),
  319. colPlaces)
  320. } else {
  321. if session.Engine.dialect.DBType() == core.MYSQL {
  322. sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.Engine.Quote(session.Statement.TableName()))
  323. } else {
  324. sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.Engine.Quote(session.Statement.TableName()))
  325. }
  326. }
  327. handleAfterInsertProcessorFunc := func(bean interface{}) {
  328. if session.IsAutoCommit {
  329. for _, closure := range session.afterClosures {
  330. closure(bean)
  331. }
  332. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  333. processor.AfterInsert()
  334. }
  335. } else {
  336. lenAfterClosures := len(session.afterClosures)
  337. if lenAfterClosures > 0 {
  338. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  339. *value = append(*value, session.afterClosures...)
  340. } else {
  341. afterClosures := make([]func(interface{}), lenAfterClosures)
  342. copy(afterClosures, session.afterClosures)
  343. session.afterInsertBeans[bean] = &afterClosures
  344. }
  345. } else {
  346. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  347. session.afterInsertBeans[bean] = nil
  348. }
  349. }
  350. }
  351. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  352. }
  353. // for postgres, many of them didn't implement lastInsertId, so we should
  354. // implemented it ourself.
  355. if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
  356. res, err := session.query("select seq_atable.currval from dual", args...)
  357. if err != nil {
  358. return 0, err
  359. }
  360. handleAfterInsertProcessorFunc(bean)
  361. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  362. session.cacheInsert(session.Statement.TableName())
  363. }
  364. if table.Version != "" && session.Statement.checkVersion {
  365. verValue, err := table.VersionColumn().ValueOf(bean)
  366. if err != nil {
  367. session.Engine.logger.Error(err)
  368. } else if verValue.IsValid() && verValue.CanSet() {
  369. verValue.SetInt(1)
  370. }
  371. }
  372. if len(res) < 1 {
  373. return 0, errors.New("insert no error but not returned id")
  374. }
  375. idByte := res[0][table.AutoIncrement]
  376. id, err := strconv.ParseInt(string(idByte), 10, 64)
  377. if err != nil || id <= 0 {
  378. return 1, err
  379. }
  380. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  381. if err != nil {
  382. session.Engine.logger.Error(err)
  383. }
  384. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  385. return 1, nil
  386. }
  387. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  388. return 1, nil
  389. } else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
  390. //assert table.AutoIncrement != ""
  391. sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
  392. res, err := session.query(sqlStr, args...)
  393. if err != nil {
  394. return 0, err
  395. }
  396. handleAfterInsertProcessorFunc(bean)
  397. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  398. session.cacheInsert(session.Statement.TableName())
  399. }
  400. if table.Version != "" && session.Statement.checkVersion {
  401. verValue, err := table.VersionColumn().ValueOf(bean)
  402. if err != nil {
  403. session.Engine.logger.Error(err)
  404. } else if verValue.IsValid() && verValue.CanSet() {
  405. verValue.SetInt(1)
  406. }
  407. }
  408. if len(res) < 1 {
  409. return 0, errors.New("insert no error but not returned id")
  410. }
  411. idByte := res[0][table.AutoIncrement]
  412. id, err := strconv.ParseInt(string(idByte), 10, 64)
  413. if err != nil || id <= 0 {
  414. return 1, err
  415. }
  416. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  417. if err != nil {
  418. session.Engine.logger.Error(err)
  419. }
  420. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  421. return 1, nil
  422. }
  423. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  424. return 1, nil
  425. } else {
  426. res, err := session.exec(sqlStr, args...)
  427. if err != nil {
  428. return 0, err
  429. }
  430. defer handleAfterInsertProcessorFunc(bean)
  431. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  432. session.cacheInsert(session.Statement.TableName())
  433. }
  434. if table.Version != "" && session.Statement.checkVersion {
  435. verValue, err := table.VersionColumn().ValueOf(bean)
  436. if err != nil {
  437. session.Engine.logger.Error(err)
  438. } else if verValue.IsValid() && verValue.CanSet() {
  439. verValue.SetInt(1)
  440. }
  441. }
  442. if table.AutoIncrement == "" {
  443. return res.RowsAffected()
  444. }
  445. var id int64
  446. id, err = res.LastInsertId()
  447. if err != nil || id <= 0 {
  448. return res.RowsAffected()
  449. }
  450. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  451. if err != nil {
  452. session.Engine.logger.Error(err)
  453. }
  454. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  455. return res.RowsAffected()
  456. }
  457. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  458. return res.RowsAffected()
  459. }
  460. }
  461. // InsertOne insert only one struct into database as a record.
  462. // The in parameter bean must a struct or a point to struct. The return
  463. // parameter is inserted and error
  464. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  465. defer session.resetStatement()
  466. if session.IsAutoClose {
  467. defer session.Close()
  468. }
  469. return session.innerInsert(bean)
  470. }
  471. func (session *Session) cacheInsert(tables ...string) error {
  472. if session.Statement.RefTable == nil {
  473. return ErrCacheFailed
  474. }
  475. table := session.Statement.RefTable
  476. cacher := session.Engine.getCacher2(table)
  477. for _, t := range tables {
  478. session.Engine.logger.Debug("[cache] clear sql:", t)
  479. cacher.ClearIds(t)
  480. }
  481. return nil
  482. }