From fc7ba98d6e29ac9fedc740c441271756f9047ecf Mon Sep 17 00:00:00 2001 From: wangzhiwei Date: Wed, 22 Apr 2026 17:04:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(billing):=20=E5=AE=8C=E5=96=84=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=AE=9A=E4=BB=B7=E7=B3=BB=E7=BB=9F=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=88=86=E6=AE=B5=E8=AE=A1=E8=B4=B9=E5=92=8C=E4=B8=8D=E5=90=8C?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 minTokens、maxTokens 和 outputMode 字段到 ModelPrice 实体 - 实现基于 token 区间和输出模式的精细化计费逻辑 - 添加 queryByModelNameAndOutputModeAndTokens 方法支持动态定价查询 - 在 AccountFrozenServiceImpl 中实现 Qwen 3.5 Plus 模型名映射 - 优化账户冻结服务中的 token 费用计算流程 - 更新 AccountService 中的余额检查和交易记录逻辑 --- .../com/kexue/skills/entity/ModelPrice.java | 11 ++- .../entity/dto/TokenConsumptionDto.java | 3 + .../kexue/skills/mapper/ModelPriceMapper.java | 11 +++ .../skills/service/ModelPriceService.java | 10 +++ .../impl/AccountFrozenServiceImpl.java | 68 +++++++++++++------ .../service/impl/AccountServiceImpl.java | 28 +++++--- .../service/impl/ModelPriceServiceImpl.java | 13 ++++ .../resources/mapper/ModelPriceMapper.xml | 35 ++++++++-- 8 files changed, 144 insertions(+), 35 deletions(-) diff --git a/src/main/java/com/kexue/skills/entity/ModelPrice.java b/src/main/java/com/kexue/skills/entity/ModelPrice.java index 6e90ef6..ec2bd12 100644 --- a/src/main/java/com/kexue/skills/entity/ModelPrice.java +++ b/src/main/java/com/kexue/skills/entity/ModelPrice.java @@ -44,9 +44,18 @@ public class ModelPrice extends BaseEntity implements Serializable { @Schema(description ="价格单位") private String unit; - @Schema(description ="备注") + @Schema(description ="备注/版本信息") private String remark; + @Schema(description ="计费区间下限(不包含)") + private Long minTokens; + + @Schema(description ="计费区间上限(包含,-1代表无穷大)") + private Long maxTokens; + + @Schema(description ="输出模式:standard=非思考模式, thinking=思考模式") + private String outputMode; + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8") @Schema(description ="创建时间") private Date createdTime; diff --git a/src/main/java/com/kexue/skills/entity/dto/TokenConsumptionDto.java b/src/main/java/com/kexue/skills/entity/dto/TokenConsumptionDto.java index 9a51678..8c16a97 100644 --- a/src/main/java/com/kexue/skills/entity/dto/TokenConsumptionDto.java +++ b/src/main/java/com/kexue/skills/entity/dto/TokenConsumptionDto.java @@ -46,4 +46,7 @@ public class TokenConsumptionDto { @Schema(description ="备注") private String remark; + @Schema(description ="输出模式:standard=非思考模式, thinking=思考模式") + private String outputMode; + } diff --git a/src/main/java/com/kexue/skills/mapper/ModelPriceMapper.java b/src/main/java/com/kexue/skills/mapper/ModelPriceMapper.java index 2e74058..f80baea 100644 --- a/src/main/java/com/kexue/skills/mapper/ModelPriceMapper.java +++ b/src/main/java/com/kexue/skills/mapper/ModelPriceMapper.java @@ -3,6 +3,7 @@ package com.kexue.skills.mapper; import com.kexue.skills.entity.ModelPrice; import com.kexue.skills.entity.dto.ModelPriceDto; import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; import java.util.List; @@ -47,6 +48,16 @@ public interface ModelPriceMapper { */ ModelPrice queryByModelName(String modelName); + /** + * 根据模型名称、输出模式和token数量查询价格规则 + * + * @param modelName 模型名称 + * @param outputMode 输出模式 + * @param tokens token数量 + * @return 实例对象 + */ + ModelPrice queryByModelNameAndOutputModeAndTokens(@Param("modelName") String modelName, @Param("outputMode") String outputMode, @Param("tokens") Long tokens); + /** * 新增数据 * diff --git a/src/main/java/com/kexue/skills/service/ModelPriceService.java b/src/main/java/com/kexue/skills/service/ModelPriceService.java index 9808881..ffa26e4 100644 --- a/src/main/java/com/kexue/skills/service/ModelPriceService.java +++ b/src/main/java/com/kexue/skills/service/ModelPriceService.java @@ -46,6 +46,16 @@ public interface ModelPriceService extends BaseService { */ ModelPrice queryByModelName(String modelName); + /** + * 根据模型名称、输出模式和token数量查询价格规则 + * + * @param modelName 模型名称 + * @param outputMode 输出模式 + * @param tokens token数量 + * @return 实例对象 + */ + ModelPrice queryByModelNameAndOutputModeAndTokens(String modelName, String outputMode, Long tokens); + /** * 新增数据 * diff --git a/src/main/java/com/kexue/skills/service/impl/AccountFrozenServiceImpl.java b/src/main/java/com/kexue/skills/service/impl/AccountFrozenServiceImpl.java index 93103cc..146382f 100644 --- a/src/main/java/com/kexue/skills/service/impl/AccountFrozenServiceImpl.java +++ b/src/main/java/com/kexue/skills/service/impl/AccountFrozenServiceImpl.java @@ -1,10 +1,12 @@ package com.kexue.skills.service.impl; +import com.kexue.skills.common.Assert; import com.kexue.skills.entity.Account; import com.kexue.skills.entity.AccountFrozen; import com.kexue.skills.entity.SysUser; import com.kexue.skills.entity.dto.AccountFrozenDto; import com.kexue.skills.entity.dto.AccountReleaseDto; +import com.kexue.skills.entity.dto.ModelPriceDto; import com.kexue.skills.exception.BizException; import com.kexue.skills.mapper.AccountFrozenMapper; import com.kexue.skills.mapper.AccountMapper; @@ -23,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional; import javax.annotation.Resource; import java.math.BigDecimal; import java.util.Date; +import java.util.List; import java.util.Objects; /** @@ -90,6 +93,10 @@ public class AccountFrozenServiceImpl implements AccountFrozenService { if (accountFrozenDto.getEstimatedInputTokens() != null && accountFrozenDto.getEstimatedOutputTokens() != null && accountFrozenDto.getModelName() != null) { + + if (accountFrozenDto.getModelName().equals("Qwen 3.5 Plus")) { + accountFrozenDto.setModelName("qwen3.5-plus"); + } // 查询模型价格信息 ModelPrice modelPrice = modelPriceService.queryByModelName(accountFrozenDto.getModelName()); @@ -193,27 +200,50 @@ public class AccountFrozenServiceImpl implements AccountFrozenService { accountReleaseDto.getUsageOutputTokens() != null && accountFrozen.getModelName() != null) { - // 查询模型价格信息 - ModelPrice modelPrice = modelPriceService.queryByModelName(accountFrozen.getModelName()); - if (modelPrice != null) { - // 计算token费用 - long inputFee = accountReleaseDto.getUsageInputTokens() / modelPrice.getInputPerCent(); - if (accountReleaseDto.getUsageInputTokens() % modelPrice.getInputPerCent() > 0) { - inputFee += 1; - } + // 查询模型价格信息(一次性查询所有价格规则,减少数据库IO) + ModelPriceDto modelPriceDto = new ModelPriceDto(); + modelPriceDto.setModelName(accountFrozen.getModelName()); + List modelPriceList = modelPriceService.getList(modelPriceDto); + + if (!modelPriceList.isEmpty()) { + // 过滤输入token的价格规则(使用standard模式) + String inputOutputMode = "standard"; + ModelPrice inputModelPrice = modelPriceList.stream() + .filter(mp -> inputOutputMode.equals(mp.getOutputMode())) + .filter(mp -> mp.getMinTokens() < accountReleaseDto.getUsageInputTokens()) + .filter(mp -> mp.getMaxTokens() == -1 || mp.getMaxTokens() >= accountReleaseDto.getUsageInputTokens()) + .max((mp1, mp2) -> mp1.getMinTokens().compareTo(mp2.getMinTokens())) + .orElse(null); + + // 过滤输出token的价格规则(使用thinking模式) + String outputOutputMode = "thinking"; + ModelPrice outputModelPrice = modelPriceList.stream() + .filter(mp -> outputOutputMode.equals(mp.getOutputMode())) + .filter(mp -> mp.getMinTokens() < accountReleaseDto.getUsageOutputTokens()) + .filter(mp -> mp.getMaxTokens() == -1 || mp.getMaxTokens() >= accountReleaseDto.getUsageOutputTokens()) + .max((mp1, mp2) -> mp1.getMinTokens().compareTo(mp2.getMinTokens())) + .orElse(null); + + if (inputModelPrice != null && outputModelPrice != null) { + // 计算token费用 + long inputFee = accountReleaseDto.getUsageInputTokens() / inputModelPrice.getInputPerCent(); + if (accountReleaseDto.getUsageInputTokens() % inputModelPrice.getInputPerCent() > 0) { + inputFee += 1; + } - long outputFee = accountReleaseDto.getUsageOutputTokens() / modelPrice.getOutputPerCent(); - if (accountReleaseDto.getUsageOutputTokens() % modelPrice.getOutputPerCent() > 0) { - outputFee += 1; - } + long outputFee = accountReleaseDto.getUsageOutputTokens() / outputModelPrice.getOutputPerCent(); + if (accountReleaseDto.getUsageOutputTokens() % outputModelPrice.getOutputPerCent() > 0) { + outputFee += 1; + } - // 总费用(分) - // 注意:因为1分=1积分,所以totalFee直接就是积分数量 - long totalFee = inputFee + outputFee; - // 转换为积分(1分=1积分,无需转换) - BigDecimal baseAmount = BigDecimal.valueOf(totalFee); - // 应用扣费系数 - finalAmount = baseAmount.multiply(accountDeductionProperties.getCoefficient()); + // 总费用(分) + // 注意:因为1分=1积分,所以totalFee直接就是积分数量 + long totalFee = inputFee + outputFee; + // 转换为积分(1分=1积分,无需转换) + BigDecimal baseAmount = BigDecimal.valueOf(totalFee); + // 应用扣费系数 + finalAmount = baseAmount.multiply(accountDeductionProperties.getCoefficient()); + } } } } diff --git a/src/main/java/com/kexue/skills/service/impl/AccountServiceImpl.java b/src/main/java/com/kexue/skills/service/impl/AccountServiceImpl.java index 29a035c..32107cd 100644 --- a/src/main/java/com/kexue/skills/service/impl/AccountServiceImpl.java +++ b/src/main/java/com/kexue/skills/service/impl/AccountServiceImpl.java @@ -313,20 +313,26 @@ public class AccountServiceImpl implements AccountService { Account account = queryByUserId(userId); Assert.notNull(account, "账户不存在"); - // 2. 查询模型价格信息 - ModelPrice modelPrice = modelPriceService.queryByModelName(dto.getModelName()); - Assert.notNull(modelPrice, "模型价格信息不存在"); + // 2. 查询输入token的价格规则(输入token使用默认的standard模式) + String inputOutputMode = "standard"; + ModelPrice inputModelPrice = modelPriceService.queryByModelNameAndOutputModeAndTokens(dto.getModelName(), inputOutputMode, Long.valueOf(dto.getInputToken())); + Assert.notNull(inputModelPrice, "输入token价格信息不存在"); - // 3. 计算金额 + // 3. 查询输出token的价格规则 + String outputMode = dto.getOutputMode() != null ? dto.getOutputMode() : "standard"; + ModelPrice outputModelPrice = modelPriceService.queryByModelNameAndOutputModeAndTokens(dto.getModelName(), outputMode, Long.valueOf(dto.getOutputToken())); + Assert.notNull(outputModelPrice, "输出token价格信息不存在"); + + // 4. 计算金额 // 输入token费用:输入token数量 / inputPerCent,不足1分按1分计算 - long inputFee = dto.getInputToken() / modelPrice.getInputPerCent(); - if (dto.getInputToken() % modelPrice.getInputPerCent() > 0) { + long inputFee = dto.getInputToken() / inputModelPrice.getInputPerCent(); + if (dto.getInputToken() % inputModelPrice.getInputPerCent() > 0) { inputFee += 1; } // 输出token费用:输出token数量 / outputPerCent,不足1分按1分计算 - long outputFee = dto.getOutputToken() / modelPrice.getOutputPerCent(); - if (dto.getOutputToken() % modelPrice.getOutputPerCent() > 0) { + long outputFee = dto.getOutputToken() / outputModelPrice.getOutputPerCent(); + if (dto.getOutputToken() % outputModelPrice.getOutputPerCent() > 0) { outputFee += 1; } @@ -335,11 +341,11 @@ public class AccountServiceImpl implements AccountService { // 转换为元 BigDecimal amount = BigDecimal.valueOf(totalFee).divide(BigDecimal.valueOf(100)); - // 4. 检查余额是否足够 + // 5. 检查余额是否足够 BigDecimal balance = account.getBalance() == null ? BigDecimal.ZERO : account.getBalance(); Assert.isTrue(balance.compareTo(amount) >= 0, "账户余额不足"); - // 5. 保存交易记录 + // 6. 保存交易记录 AccountTransaction transaction = new AccountTransaction(); transaction.setUserId(userId); transaction.setUserName(account.getUserName()); @@ -359,7 +365,7 @@ public class AccountServiceImpl implements AccountService { transaction.setQuestion(dto.getQuestion()); accountTransactionMapper.insert(transaction); - // 6. 更新账户余额 + // 7. 更新账户余额 accountMapper.updateBalance(userId, amount, 2); return amount; } diff --git a/src/main/java/com/kexue/skills/service/impl/ModelPriceServiceImpl.java b/src/main/java/com/kexue/skills/service/impl/ModelPriceServiceImpl.java index 057adf4..7e70ef9 100644 --- a/src/main/java/com/kexue/skills/service/impl/ModelPriceServiceImpl.java +++ b/src/main/java/com/kexue/skills/service/impl/ModelPriceServiceImpl.java @@ -70,6 +70,19 @@ public class ModelPriceServiceImpl implements ModelPriceService { return this.modelPriceMapper.queryByModelName(modelName); } + /** + * 根据模型名称、输出模式和token数量查询价格规则 + * + * @param modelName 模型名称 + * @param outputMode 输出模式 + * @param tokens token数量 + * @return 实例对象 + */ + @Override + public ModelPrice queryByModelNameAndOutputModeAndTokens(String modelName, String outputMode, Long tokens) { + return this.modelPriceMapper.queryByModelNameAndOutputModeAndTokens(modelName, outputMode, tokens); + } + /** * 新增数据 * diff --git a/src/main/resources/mapper/ModelPriceMapper.xml b/src/main/resources/mapper/ModelPriceMapper.xml index 73463e7..8550aa7 100644 --- a/src/main/resources/mapper/ModelPriceMapper.xml +++ b/src/main/resources/mapper/ModelPriceMapper.xml @@ -12,6 +12,9 @@ + + + @@ -19,7 +22,7 @@ @@ -27,7 +30,7 @@ @@ -35,7 +38,7 @@ select - id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, created_time, updated_time + id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, min_tokens, max_tokens, output_mode, created_time, updated_time from model_price @@ -77,6 +80,9 @@ output_per_cent, unit, remark, + min_tokens, + max_tokens, + output_mode, created_time, updated_time, @@ -89,6 +95,9 @@ #{outputPerCent}, #{unit}, #{remark}, + #{minTokens}, + #{maxTokens}, + #{outputMode}, #{createdTime}, #{updatedTime}, @@ -106,6 +115,9 @@ output_per_cent = #{outputPerCent}, unit = #{unit}, remark = #{remark}, + min_tokens = #{minTokens}, + max_tokens = #{maxTokens}, + output_mode = #{outputMode}, created_time = #{createdTime}, updated_time = #{updatedTime}, @@ -118,4 +130,19 @@ where id = #{id} + + +