一、前言
经过对qwen-7b-chat的部署以及与vllm的推理加速的整合,我们成功构建了一套高性能、高可靠、高安全的AI服务能力。现在,我们将着手整合具体的业务场景,以实现完整可落地的功能交付。
作为上游部门,通常会采用最常用的方式来接入下游服务。为了调用我们的AI服务,我们将使用Java语言,并分别使用HttpURLConnection、OkHttp、HttpClient等工具来实现调用。这样可以确保我们能够高效地与AI服务进行交互。
二、术语
2.1.OkHttp
是一个开源的Java和Kotlin HTTP客户端库,用于进行网络请求。OkHttp支持HTTP/1.1和HTTP/2协议,具有连接池、请求重试、缓存、拦截器等功能。它还提供了异步和同步请求的支持,并且可以与各种平台和框架无缝集成,是Android开发中常用的网络请求库之一。通过使用OkHttp,开发人员可以轻松地发送HTTP请求、处理响应以及管理网络连接,从而加快应用程序的网络通信速度和效率。
2.2.HttpClient
是一个用于发送HTTP请求和接收HTTP响应的开源库。它提供了一种方便的方式来与Web服务器进行通信,并执行各种HTTP操作,例如发送GET请求、POST请求等。HttpClient库通常用于编写客户端应用程序或服务,这些应用程序需要与Web服务器或Web API进行通信。它提供了许多功能,包括连接管理、身份验证、请求和响应拦截、Cookie管理等。
2.3.HttpURLConnection
是Java提供的一个用于发送HTTP请求和接收HTTP响应的类。它是Java标准库中的一部分,用于与Web服务器进行通信。HttpURLConnection类提供了一组方法,使您能够创建HTTP连接、设置请求方法(如GET、POST等)、设置请求头、设置请求体和其他参数,并发送请求到指定的URL。它还提供了方法来获取HTTP响应的状态码、响应头和响应体等信息。
OkHttp和HttpClient提供了更丰富的功能和更好的性能,适用于大多数场景;HttpURLConnection是Java标准库中的类,仅提供基本的HTTP功能,但无需引入额外依赖。
三、前置条件
3.1. 完成Qwen-7b-Chat(Qwen-1_8B-Chat)模型的本地部署或服务端部署
参见“开源模型应用落地-qwen-7b-chat与vllm实现推理加速的正确姿势”系列文章
3.2. 完成对外服务接口的封装,屏蔽不同模型的调用差异
参见“开源模型应用落地-qwen-7b-chat与vllm实现推理加速的正确姿势”系列文章
四、技术实现
4.1. HttpURLConnection
import lombok.extern.slf4j.Slf4j;import java.io.ByteArrayOutputStream;import java.io.InputStream;import java.io.OutputStream;import java.net.HttpURLConnection;import java.net.URL;import java.nio.charset.StandardCharsets;import java.util.Objects;@Slf4jpublic class QWenCallTest { private static String url = "http://127.0.0.1:9999/api/chat"; private static String DEFAULT_TEMPLATE = "{\"prompt\":\"%s\",\"history\":%x,\"top_p\":0.9, \"temperature\":0.45,\"repetition_penalty\":1.1, \"max_new_tokens\":8192}"; private static String DEFAULT_USERID = "xxxxx"; private static String DEFAULT_SECRET = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; private static int DEFAULT_CONNECTION_TIMEOUT = 3 * 1000; private static int DEFAULT_READ_TIMEOUT = 30 * 1000; public static void main(String[] args) { String question = "我家周边有什么好吃、好玩的地方嘛?"; String history = "[{\n" + "\"from\": \"user\",\n" + "\"value\": \"你好\"\n" + "},\n" + "{\n" + "\"from\": \"assistant\",\n" + "\"value\": \"你好!有什么我能为你效劳的吗?\"\n" + "},\n" + "{\n" + "\"from\": \"user\",\n" + "\"value\": \"我家在广州,你呢?\"\n" + "},\n" + "{\n" + "\"from\": \"assistant\",\n" + "\"value\": \"我是一个人工智能助手,没有具体的家。\"\n" + "}]"; String prompt = DEFAULT_TEMPLATE.replace("%s", question).replace("%x", history); log.info("prompt: {}", prompt); HttpURLConnection conn = null; OutputStream os = null; try { //1.设置URL URL urlObject = new URL(url); //2.打开URL连接 conn = (HttpURLConnection) urlObject.openConnection(); //3.设置请求方式 conn.setRequestMethod("POST"); conn.setRequestProperty("Content-Type", "application/json;charset=utf-8"); conn.setRequestProperty("Accept", "text/event-stream"); conn.setRequestProperty("userId", DEFAULT_USERID); conn.setRequestProperty("secret", DEFAULT_SECRET); conn.setDoOutput(true); conn.setDoInput(true); // 设置连接超时时间为60秒 conn.setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT); // 设置读取超时时间为60秒 conn.setReadTimeout(DEFAULT_READ_TIMEOUT); os = conn.getOutputStream(); os.write(prompt.getBytes("utf-8")); } catch (Exception e) { 
log.error("请求模型接口异常", e); } finally { if(!Objects.isNull(os)){ try { os.flush(); os.close(); } catch (Exception e) { } } } InputStream is = null; try{ if(!Objects.isNull(conn)){ int responseCode = conn.getResponseCode(); log.info("Response Code: " + responseCode); if(responseCode == 200){ is = conn.getInputStream(); }else{ is = conn.getErrorStream(); } byte[] bytes = new byte[1024]; int len = 0; while ((len = is.read(bytes)) != -1) { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); outputStream.write(bytes, 0, len); String response = new String(outputStream.toByteArray(), StandardCharsets.UTF_8); log.info(response); } } } catch (Exception e) { log.error("请求模型接口异常", e); } finally { if (!Objects.isNull(is)) { try { is.close(); } catch (Exception e) { e.printStackTrace(); } } } }}
4.2. OkHttp
import com.alibaba.fastjson.JSON;import lombok.extern.slf4j.Slf4j;import okhttp3.*;import java.io.ByteArrayOutputStream;import java.io.InputStream;import java.nio.charset.StandardCharsets;import java.util.Objects;import java.util.concurrent.CountDownLatch;import java.util.concurrent.TimeUnit;@Slf4jpublic class QWenCallTest { private static String url = "http://127.0.0.1:9999/api/chat"; private static String DEFAULT_TEMPLATE = "{\"prompt\":\"%s\",\"history\":%x,\"top_p\":0.9, \"temperature\":0.45,\"repetition_penalty\":1.2, \"max_new_tokens\":8192}"; private static long DEFAULT_CONNECTION_TIMEOUT = 3 * 1000; private static long DEFAULT_WRITE_TIMEOUT = 15 * 1000; private static long DEFAULT_READ_TIMEOUT = 15 * 1000; private final static Request.Builder buildHeader(Request.Builder builder) { return builder .addHeader("Content-Type", "application/json; charset=utf-8") .addHeader("userId", "xxxxx") .addHeader("secret", "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); } private final static Request buildRequest(String prompt) { //创建一个请求体对象(body) MediaType mediaType = MediaType.parse("application/json"); RequestBody requestBody = RequestBody.create(mediaType,prompt); return buildHeader(new Request.Builder().post(requestBody)) .url(url).build(); } public static void chat(String question,String history,CountDownLatch countDownLatch) { //定义请求的参数 String prompt = DEFAULT_TEMPLATE.replace("%s", question).replace("%x", history); log.info("prompt: {}", prompt); //创建一个请求对象 Request request = buildRequest(prompt); //发送请求:创建了一个请求工具对象,调用执行request对象 OkHttpClient okHttpClient = new OkHttpClient().newBuilder() .connectTimeout(DEFAULT_CONNECTION_TIMEOUT, TimeUnit.MILLISECONDS) .writeTimeout(DEFAULT_WRITE_TIMEOUT, TimeUnit.MILLISECONDS) .readTimeout(DEFAULT_READ_TIMEOUT, TimeUnit.MILLISECONDS) .build(); InputStream is = null; try { Response response = okHttpClient.newCall(request).execute(); //正常返回 if(response.code() == 200){ //打印返回的字符数据 is = response.body().byteStream(); byte[] bytes = new 
byte[1024]; int len = 0; while ((len = is.read(bytes)) != -1) { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); outputStream.write(bytes, 0, len); outputStream.flush(); String result = new String(outputStream.toByteArray(), StandardCharsets.UTF_8); log.info(result); } } else{ String result = response.body().string(); String jsonStr = JSON.parseObject(result).toJSONString(); log.info(jsonStr); } } catch (Throwable e) { log.error("执行异常", e); } finally { if (!Objects.isNull(is)) { try { is.close(); } catch (Exception e) { e.printStackTrace(); } } countDownLatch.countDown(); } } public static void main(String[] args) { CountDownLatch countDownLatch = new CountDownLatch(1); String question = "我家周边有什么好吃、好玩的地方嘛?"; String history = "[{\n" + "\"from\": \"user\",\n" + "\"value\": \"你好\"\n" + "},\n" + "{\n" + "\"from\": \"assistant\",\n" + "\"value\": \"你好!有什么我能为你效劳的吗?\"\n" + "},\n" + "{\n" + "\"from\": \"user\",\n" + "\"value\": \"我家在广州,你呢?\"\n" + "},\n" + "{\n" + "\"from\": \"assistant\",\n" + "\"value\": \"我是一个人工智能助手,没有具体的家。\"\n" + "}]"; //流式输出 long starttime = System.currentTimeMillis(); chat(question,history,countDownLatch); long endtime = System.currentTimeMillis(); System.err.println((endtime-starttime)); try { countDownLatch.await(); } catch (InterruptedException e) { e.printStackTrace(); } }}
maven依赖
<dependency> <groupId>com.squareup.okhttp3</groupId> <artifactId>okhttp</artifactId> <version>3.14.9</version></dependency>
4.3. HttpClient(示例基于AsyncHttpClient实现)
import lombok.extern.slf4j.Slf4j;import org.asynchttpclient.AsyncHttpClient;import org.asynchttpclient.AsyncHttpClientConfig;import org.asynchttpclient.DefaultAsyncHttpClient;import org.asynchttpclient.DefaultAsyncHttpClientConfig;import org.asynchttpclient.channel.DefaultKeepAliveStrategy;import java.io.IOException;import java.util.concurrent.CountDownLatch;import java.util.concurrent.TimeUnit;@Slf4jpublic class QwenCallTest { private static String url = "http://127.0.0.1:9999/api/chat"; private static String DEFAULT_TEMPLATE = "{\"prompt\":\"%s\",\"history\":%x,\"top_p\":0.9, \"temperature\":0.45,\"repetition_penalty\":1.1, \"max_new_tokens\":8192}"; private static int DEFAULT_CONNECTION_TIMEOUT = 3 * 1000; private static int DEFAULT_REQUEST_TIMEOUT = 15* 1000; private static int DEFAULT_READ_TIMEOUT = 15* 1000; private static String DEFAULT_USERID = "xxxxx"; private static String DEFAULT_SECRET = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; private static AsyncHttpClientConfig asyncHttpClientConfig = new DefaultAsyncHttpClientConfig.Builder() .setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT) .setReadTimeout(DEFAULT_READ_TIMEOUT) .setRequestTimeout(DEFAULT_REQUEST_TIMEOUT) .setTcpNoDelay(true) .setMaxConnections(1_000_000) .setMaxConnectionsPerHost(100_000) .setMaxRequestRetry(0) .setSoReuseAddress(true) .setKeepAlive(true) .setKeepAliveStrategy(new DefaultKeepAliveStrategy()) .build(); public static final AsyncHttpClient ugc_asyncHttpClient = new DefaultAsyncHttpClient(asyncHttpClientConfig); public static void chat(String question,String history, CountDownLatch countDownLatch ) { String prompt = DEFAULT_TEMPLATE.replace("%s", question).replace("%x", history); log.info("prompt: {}", prompt); try { ugc_asyncHttpClient.preparePost(url) .addHeader("Content-Type", "application/json; charset=utf-8") .addHeader("userId", DEFAULT_USERID) .addHeader("secret", DEFAULT_SECRET) .addHeader("Accept", "text/event-stream") .setBody(prompt) .execute(new 
QwenStreamHandler(countDownLatch)) .get(30, TimeUnit.SECONDS); } catch (Exception e) { log.error(prompt + " >> 出现异常"); } } public static void main(String[] args) { CountDownLatch countDownLatch = new CountDownLatch(1); String question = "我家周边有什么好吃、好玩的地方嘛?"; String history = "[{\n" + "\"from\": \"user\",\n" + "\"value\": \"你好\"\n" + "},\n" + "{\n" + "\"from\": \"assistant\",\n" + "\"value\": \"你好!有什么我能为你效劳的吗?\"\n" + "},\n" + "{\n" + "\"from\": \"user\",\n" + "\"value\": \"我家在广州,你呢?\"\n" + "},\n" + "{\n" + "\"from\": \"assistant\",\n" + "\"value\": \"我是一个人工智能助手,没有具体的家。\"\n" + "}]"; chat(question,history,countDownLatch); try { countDownLatch.await(); } catch (InterruptedException e) { e.printStackTrace(); } try { ugc_asyncHttpClient.close(); } catch (IOException e) { e.printStackTrace(); } }}
import io.netty.handler.codec.http.HttpHeaders;import lombok.extern.slf4j.Slf4j;import org.asynchttpclient.HttpResponseBodyPart;import org.asynchttpclient.HttpResponseStatus;import org.asynchttpclient.handler.StreamedAsyncHandler;import org.asynchttpclient.netty.EagerResponseBodyPart;import org.reactivestreams.Publisher;import org.reactivestreams.Subscriber;import org.reactivestreams.Subscription;import java.util.concurrent.CountDownLatch;@Slf4jpublic class QwenStreamHandler implements StreamedAsyncHandler<String> { private CountDownLatch countDownLatch; public QwenStreamHandler(CountDownLatch countDownLatch){ this.countDownLatch = countDownLatch; } @Override public State onStream(Publisher publisher) { publisher.subscribe(new Subscriber() { @Override public void onSubscribe(Subscription subscription) { subscription.request(Long.MAX_VALUE); } @Override public void onNext(Object obj) { try{ if(obj instanceof EagerResponseBodyPart){ EagerResponseBodyPart part = (EagerResponseBodyPart)obj; byte[] bytes = part.getBodyPartBytes(); String words = new String(bytes,"UTF-8"); log.info(words); } }catch(Throwable e){ log.error("系统异常",e); } } @Override public void onError(Throwable throwable) { log.error("系统异常",throwable); } @Override public void onComplete() { countDownLatch.countDown(); } }); return State.CONTINUE; } @Override public State onStatusReceived(HttpResponseStatus responseStatus) throws Exception { log.info("onStatusReceived: {}",responseStatus.getStatusCode()); return responseStatus.getStatusCode() == 200 ? State.CONTINUE : State.ABORT; } @Override public State onHeadersReceived(HttpHeaders headers) throws Exception { return State.CONTINUE; } @Override public State onBodyPartReceived(HttpResponseBodyPart bodyPart) throws Exception { return State.CONTINUE; } @Override public void onThrowable(Throwable t) { log.error("onThrowable", t); } @Override public String onCompleted() throws Exception { return State.ABORT.name(); }}
maven依赖
<dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> <version>4.5.12</version></dependency><dependency> <groupId>org.asynchttpclient</groupId> <artifactId>async-http-client</artifactId> <version>2.12.3</version></dependency>