feat(billing): 完善模型定价系统支持分段计费和不同输出模式

- 新增 minTokens、maxTokens 和 outputMode 字段到 ModelPrice 实体
- 实现基于 token 区间和输出模式的精细化计费逻辑
- 添加 queryByModelNameAndOutputModeAndTokens 方法支持动态定价查询
- 在 AccountFrozenServiceImpl 中实现 Qwen 3.5 Plus 模型名映射
- 优化账户冻结服务中的 token 费用计算流程
- 更新 AccountService 中的余额检查和交易记录逻辑
This commit is contained in:
wangzhiwei 2026-04-22 17:04:27 +08:00
parent eef2b68291
commit fc7ba98d6e
8 changed files with 144 additions and 35 deletions

View File

@ -44,9 +44,18 @@ public class ModelPrice extends BaseEntity implements Serializable {
@Schema(description ="价格单位")
private String unit;
@Schema(description ="备注")
@Schema(description ="备注/版本信息")
private String remark;
@Schema(description ="计费区间下限(不包含)")
private Long minTokens;
@Schema(description ="计费区间上限(包含,-1代表无穷大")
private Long maxTokens;
@Schema(description ="输出模式standard=非思考模式, thinking=思考模式")
private String outputMode;
@JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
@Schema(description ="创建时间")
private Date createdTime;

View File

@ -46,4 +46,7 @@ public class TokenConsumptionDto {
@Schema(description ="备注")
private String remark;
@Schema(description ="输出模式standard=非思考模式, thinking=思考模式")
private String outputMode;
}

View File

@ -3,6 +3,7 @@ package com.kexue.skills.mapper;
import com.kexue.skills.entity.ModelPrice;
import com.kexue.skills.entity.dto.ModelPriceDto;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@ -47,6 +48,16 @@ public interface ModelPriceMapper {
*/
ModelPrice queryByModelName(String modelName);
/**
* 根据模型名称输出模式和token数量查询价格规则
*
* @param modelName 模型名称
* @param outputMode 输出模式
* @param tokens token数量
* @return 实例对象
*/
ModelPrice queryByModelNameAndOutputModeAndTokens(@Param("modelName") String modelName, @Param("outputMode") String outputMode, @Param("tokens") Long tokens);
/**
* 新增数据
*

View File

@ -46,6 +46,16 @@ public interface ModelPriceService extends BaseService {
*/
ModelPrice queryByModelName(String modelName);
/**
* 根据模型名称输出模式和token数量查询价格规则
*
* @param modelName 模型名称
* @param outputMode 输出模式
* @param tokens token数量
* @return 实例对象
*/
ModelPrice queryByModelNameAndOutputModeAndTokens(String modelName, String outputMode, Long tokens);
/**
* 新增数据
*

View File

@ -1,10 +1,12 @@
package com.kexue.skills.service.impl;
import com.kexue.skills.common.Assert;
import com.kexue.skills.entity.Account;
import com.kexue.skills.entity.AccountFrozen;
import com.kexue.skills.entity.SysUser;
import com.kexue.skills.entity.dto.AccountFrozenDto;
import com.kexue.skills.entity.dto.AccountReleaseDto;
import com.kexue.skills.entity.dto.ModelPriceDto;
import com.kexue.skills.exception.BizException;
import com.kexue.skills.mapper.AccountFrozenMapper;
import com.kexue.skills.mapper.AccountMapper;
@ -23,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional;
import javax.annotation.Resource;
import java.math.BigDecimal;
import java.util.Date;
import java.util.List;
import java.util.Objects;
/**
@ -91,6 +94,10 @@ public class AccountFrozenServiceImpl implements AccountFrozenService {
accountFrozenDto.getEstimatedOutputTokens() != null &&
accountFrozenDto.getModelName() != null) {
if (accountFrozenDto.getModelName().equals("Qwen 3.5 Plus")) {
accountFrozenDto.setModelName("qwen3.5-plus");
}
// 查询模型价格信息
ModelPrice modelPrice = modelPriceService.queryByModelName(accountFrozenDto.getModelName());
if (modelPrice != null) {
@ -193,27 +200,50 @@ public class AccountFrozenServiceImpl implements AccountFrozenService {
accountReleaseDto.getUsageOutputTokens() != null &&
accountFrozen.getModelName() != null) {
// 查询模型价格信息
ModelPrice modelPrice = modelPriceService.queryByModelName(accountFrozen.getModelName());
if (modelPrice != null) {
// 计算token费用
long inputFee = accountReleaseDto.getUsageInputTokens() / modelPrice.getInputPerCent();
if (accountReleaseDto.getUsageInputTokens() % modelPrice.getInputPerCent() > 0) {
inputFee += 1;
}
// 查询模型价格信息一次性查询所有价格规则减少数据库IO
ModelPriceDto modelPriceDto = new ModelPriceDto();
modelPriceDto.setModelName(accountFrozen.getModelName());
List<ModelPrice> modelPriceList = modelPriceService.getList(modelPriceDto);
long outputFee = accountReleaseDto.getUsageOutputTokens() / modelPrice.getOutputPerCent();
if (accountReleaseDto.getUsageOutputTokens() % modelPrice.getOutputPerCent() > 0) {
outputFee += 1;
}
if (!modelPriceList.isEmpty()) {
// 过滤输入token的价格规则使用standard模式
String inputOutputMode = "standard";
ModelPrice inputModelPrice = modelPriceList.stream()
.filter(mp -> inputOutputMode.equals(mp.getOutputMode()))
.filter(mp -> mp.getMinTokens() < accountReleaseDto.getUsageInputTokens())
.filter(mp -> mp.getMaxTokens() == -1 || mp.getMaxTokens() >= accountReleaseDto.getUsageInputTokens())
.max((mp1, mp2) -> mp1.getMinTokens().compareTo(mp2.getMinTokens()))
.orElse(null);
// 总费用
// 注意因为1分=1积分所以totalFee直接就是积分数量
long totalFee = inputFee + outputFee;
// 转换为积分1分=1积分无需转换
BigDecimal baseAmount = BigDecimal.valueOf(totalFee);
// 应用扣费系数
finalAmount = baseAmount.multiply(accountDeductionProperties.getCoefficient());
// 过滤输出token的价格规则使用thinking模式
String outputOutputMode = "thinking";
ModelPrice outputModelPrice = modelPriceList.stream()
.filter(mp -> outputOutputMode.equals(mp.getOutputMode()))
.filter(mp -> mp.getMinTokens() < accountReleaseDto.getUsageOutputTokens())
.filter(mp -> mp.getMaxTokens() == -1 || mp.getMaxTokens() >= accountReleaseDto.getUsageOutputTokens())
.max((mp1, mp2) -> mp1.getMinTokens().compareTo(mp2.getMinTokens()))
.orElse(null);
if (inputModelPrice != null && outputModelPrice != null) {
// 计算token费用
long inputFee = accountReleaseDto.getUsageInputTokens() / inputModelPrice.getInputPerCent();
if (accountReleaseDto.getUsageInputTokens() % inputModelPrice.getInputPerCent() > 0) {
inputFee += 1;
}
long outputFee = accountReleaseDto.getUsageOutputTokens() / outputModelPrice.getOutputPerCent();
if (accountReleaseDto.getUsageOutputTokens() % outputModelPrice.getOutputPerCent() > 0) {
outputFee += 1;
}
// 总费用
// 注意因为1分=1积分所以totalFee直接就是积分数量
long totalFee = inputFee + outputFee;
// 转换为积分1分=1积分无需转换
BigDecimal baseAmount = BigDecimal.valueOf(totalFee);
// 应用扣费系数
finalAmount = baseAmount.multiply(accountDeductionProperties.getCoefficient());
}
}
}
}

View File

@ -313,20 +313,26 @@ public class AccountServiceImpl implements AccountService {
Account account = queryByUserId(userId);
Assert.notNull(account, "账户不存在");
// 2. 查询模型价格信息
ModelPrice modelPrice = modelPriceService.queryByModelName(dto.getModelName());
Assert.notNull(modelPrice, "模型价格信息不存在");
// 2. 查询输入token的价格规则输入token使用默认的standard模式
String inputOutputMode = "standard";
ModelPrice inputModelPrice = modelPriceService.queryByModelNameAndOutputModeAndTokens(dto.getModelName(), inputOutputMode, Long.valueOf(dto.getInputToken()));
Assert.notNull(inputModelPrice, "输入token价格信息不存在");
// 3. 计算金额
// 3. 查询输出token的价格规则
String outputMode = dto.getOutputMode() != null ? dto.getOutputMode() : "standard";
ModelPrice outputModelPrice = modelPriceService.queryByModelNameAndOutputModeAndTokens(dto.getModelName(), outputMode, Long.valueOf(dto.getOutputToken()));
Assert.notNull(outputModelPrice, "输出token价格信息不存在");
// 4. 计算金额
// 输入token费用输入token数量 / inputPerCent不足1分按1分计算
long inputFee = dto.getInputToken() / modelPrice.getInputPerCent();
if (dto.getInputToken() % modelPrice.getInputPerCent() > 0) {
long inputFee = dto.getInputToken() / inputModelPrice.getInputPerCent();
if (dto.getInputToken() % inputModelPrice.getInputPerCent() > 0) {
inputFee += 1;
}
// 输出token费用输出token数量 / outputPerCent不足1分按1分计算
long outputFee = dto.getOutputToken() / modelPrice.getOutputPerCent();
if (dto.getOutputToken() % modelPrice.getOutputPerCent() > 0) {
long outputFee = dto.getOutputToken() / outputModelPrice.getOutputPerCent();
if (dto.getOutputToken() % outputModelPrice.getOutputPerCent() > 0) {
outputFee += 1;
}
@ -335,11 +341,11 @@ public class AccountServiceImpl implements AccountService {
// 转换为元
BigDecimal amount = BigDecimal.valueOf(totalFee).divide(BigDecimal.valueOf(100));
// 4. 检查余额是否足够
// 5. 检查余额是否足够
BigDecimal balance = account.getBalance() == null ? BigDecimal.ZERO : account.getBalance();
Assert.isTrue(balance.compareTo(amount) >= 0, "账户余额不足");
// 5. 保存交易记录
// 6. 保存交易记录
AccountTransaction transaction = new AccountTransaction();
transaction.setUserId(userId);
transaction.setUserName(account.getUserName());
@ -359,7 +365,7 @@ public class AccountServiceImpl implements AccountService {
transaction.setQuestion(dto.getQuestion());
accountTransactionMapper.insert(transaction);
// 6. 更新账户余额
// 7. 更新账户余额
accountMapper.updateBalance(userId, amount, 2);
return amount;
}

View File

@ -70,6 +70,19 @@ public class ModelPriceServiceImpl implements ModelPriceService {
return this.modelPriceMapper.queryByModelName(modelName);
}
/**
* 根据模型名称输出模式和token数量查询价格规则
*
* @param modelName 模型名称
* @param outputMode 输出模式
* @param tokens token数量
* @return 实例对象
*/
@Override
public ModelPrice queryByModelNameAndOutputModeAndTokens(String modelName, String outputMode, Long tokens) {
return this.modelPriceMapper.queryByModelNameAndOutputModeAndTokens(modelName, outputMode, tokens);
}
/**
* 新增数据
*

View File

@ -12,6 +12,9 @@
<result property="outputPerCent" column="output_per_cent" jdbcType="BIGINT"/>
<result property="unit" column="unit" jdbcType="VARCHAR"/>
<result property="remark" column="remark" jdbcType="VARCHAR"/>
<result property="minTokens" column="min_tokens" jdbcType="BIGINT"/>
<result property="maxTokens" column="max_tokens" jdbcType="BIGINT"/>
<result property="outputMode" column="output_mode" jdbcType="VARCHAR"/>
<result property="createdTime" column="created_time" jdbcType="TIMESTAMP"/>
<result property="updatedTime" column="updated_time" jdbcType="TIMESTAMP"/>
</resultMap>
@ -19,7 +22,7 @@
<!--查询单个-->
<select id="queryById" resultMap="ModelPriceMap">
select
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, created_time, updated_time
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, min_tokens, max_tokens, output_mode, created_time, updated_time
from model_price
where id = #{id}
</select>
@ -27,7 +30,7 @@
<!--通过模型名称查询-->
<select id="queryByModelName" resultMap="ModelPriceMap">
select
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, created_time, updated_time
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, min_tokens, max_tokens, output_mode, created_time, updated_time
from model_price
where model_name = #{modelName}
</select>
@ -35,7 +38,7 @@
<!--分页查询-->
<select id="getPageList" resultMap="ModelPriceMap">
select
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, created_time, updated_time
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, min_tokens, max_tokens, output_mode, created_time, updated_time
from model_price
<where>
<if test="vendor != null and vendor != ''">
@ -53,7 +56,7 @@
<!--查询列表-->
<select id="getList" resultMap="ModelPriceMap">
select
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, created_time, updated_time
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, min_tokens, max_tokens, output_mode, created_time, updated_time
from model_price
<where>
<if test="vendor != null and vendor != ''">
@ -77,6 +80,9 @@
<if test="outputPerCent != null">output_per_cent,</if>
<if test="unit != null">unit,</if>
<if test="remark != null">remark,</if>
<if test="minTokens != null">min_tokens,</if>
<if test="maxTokens != null">max_tokens,</if>
<if test="outputMode != null">output_mode,</if>
<if test="createdTime != null">created_time,</if>
<if test="updatedTime != null">updated_time,</if>
</trim>
@ -89,6 +95,9 @@
<if test="outputPerCent != null">#{outputPerCent},</if>
<if test="unit != null">#{unit},</if>
<if test="remark != null">#{remark},</if>
<if test="minTokens != null">#{minTokens},</if>
<if test="maxTokens != null">#{maxTokens},</if>
<if test="outputMode != null">#{outputMode},</if>
<if test="createdTime != null">#{createdTime},</if>
<if test="updatedTime != null">#{updatedTime},</if>
</trim>
@ -106,6 +115,9 @@
<if test="outputPerCent != null">output_per_cent = #{outputPerCent},</if>
<if test="unit != null">unit = #{unit},</if>
<if test="remark != null">remark = #{remark},</if>
<if test="minTokens != null">min_tokens = #{minTokens},</if>
<if test="maxTokens != null">max_tokens = #{maxTokens},</if>
<if test="outputMode != null">output_mode = #{outputMode},</if>
<if test="createdTime != null">created_time = #{createdTime},</if>
<if test="updatedTime != null">updated_time = #{updatedTime},</if>
</set>
@ -118,4 +130,19 @@
where id = #{id}
</delete>
<!--根据模型名称、输出模式和token数量查询价格规则-->
<select id="queryByModelNameAndOutputModeAndTokens" resultMap="ModelPriceMap">
<![CDATA[
select
id, vendor, model_name, input_price, output_price, input_per_cent, output_per_cent, unit, remark, min_tokens, max_tokens, output_mode, created_time, updated_time
from model_price
where model_name = #{modelName}
and output_mode = #{outputMode}
and min_tokens < #{tokens}
and (max_tokens = -1 or max_tokens >= #{tokens})
order by min_tokens desc
limit 1
]]>
</select>
</mapper>